@@ -36,7 +36,9 @@
 from modelopt.onnx.op_types import is_fusible_scaling_op
 from modelopt.onnx.quantization.calib_utils import RandomDataProvider
 from modelopt.onnx.quantization.graph_utils import (
-    _find_quantizable_weights,
+    _find_int4_quantizable_weights as _find_quantizable_weights,
+)
+from modelopt.onnx.quantization.graph_utils import (
     expand_node_names_from_patterns,
     get_precision_info,
     get_tensor_consumer_nodes,
@@ -50,9 +52,9 @@
     find_scales,
     get_num_bits,
     quant_tensor,
+    reshape_scales_for_per_channel_nodes,
     rtn,
     update_block_size,
-    update_scale_map_for_per_channel_nodes,
 )
 from modelopt.onnx.utils import save_onnx
 
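Two renames land in the imports: `_find_quantizable_weights` is now an alias of `_find_int4_quantizable_weights`, and `update_scale_map_for_per_channel_nodes` becomes `reshape_scales_for_per_channel_nodes`, a name that better matches what the helper does. A minimal sketch of the plausible behavior behind the new name — the signature and the 2-D-to-1-D convention here are assumptions, not the actual modelopt implementation:

```python
import numpy as np

def reshape_scales_for_per_channel_nodes_sketch(
    scales_map: dict[str, np.ndarray],
    block_size: int,
    precision_info: dict[str, int] | None,
) -> dict[str, np.ndarray]:
    """Flatten block-wise scales to 1-D per-channel scales for 8-bit weights.

    Assumption of this sketch: per-channel DequantizeLinear expects a 1-D
    scale of shape (out_channels,), while block quantization produces a 2-D
    (n_blocks, out_channels) array. For 8-bit (per-channel) weights
    n_blocks == 1, so the leading axis can simply be dropped.
    """
    if precision_info is None:
        return scales_map
    for name, scale in scales_map.items():
        if precision_info.get(name) == 8 and scale.ndim > 1:
            scales_map[name] = scale.reshape(-1)  # (1, cols) -> (cols,)
    return scales_map
```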
@@ -121,6 +123,7 @@ def _quantize_gather_nodes(
             continue
         name = in_tensor.name
         w = in_tensor.values
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, name)
         block_size_updated = update_block_size(
             num_bits, block_size, w=w, quantize_axis=gather_quantize_axis
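The pattern added here — look up `num_bits` for the weight, then recompute the block size before quantizing — is what routes 8-bit weights to per-channel quantization, and it recurs in every hunk below. A minimal sketch of what `update_block_size` plausibly does; the per-channel convention and the exact signature are assumptions of this sketch:

```python
import numpy as np

def update_block_size_sketch(
    num_bits: int,
    block_size: int,
    w: np.ndarray | None = None,
    quantize_axis: int = 0,
) -> int:
    """Return the effective block size for a weight tensor.

    Assumption: 8-bit weights are quantized per channel, expressed as a
    single block spanning the whole quantize axis; 4-bit weights keep the
    user-provided block size.
    """
    if num_bits == 8 and w is not None:
        return w.shape[quantize_axis]  # one block per channel
    return block_size
```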
@@ -170,7 +173,7 @@ def _quantize_gather_nodes(
         )
     else:
         logger.info("Found 0 Gather nodes to quantize")
-    scales_map = update_scale_map_for_per_channel_nodes(scales_map, block_size, precision_info)
+    scales_map = reshape_scales_for_per_channel_nodes(scales_map, block_size, precision_info)
     return weights_map, scales_map, zero_point_map
 
 
@@ -221,6 +224,7 @@ def quantize_rtn(
     precision_info = get_precision_info(onnx_model, nodes_to_exclude, **kwargs)
     for name, w in gemm_weights.items():
         logger.debug(f"Computing scales for weight {name} of shape {w.shape}")
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, name)
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         s, zp = find_scales(np.asarray(w), block_size_updated, num_bits=num_bits)
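`find_scales` and `rtn` implement round-to-nearest block quantization. A self-contained sketch of the idea, simplified to symmetric 2-D weights — the real helpers also handle zero points, quantize axes, and CuPy arrays:

```python
import numpy as np

def rtn_sketch(w: np.ndarray, block_size: int, num_bits: int = 4):
    """Symmetric round-to-nearest block quantization of a (rows, cols) weight.

    Each block of `block_size` consecutive rows in a column shares one scale.
    """
    qmax = 2 ** (num_bits - 1) - 1                # 7 for int4, 127 for int8
    rows, cols = w.shape
    blocks = w.reshape(rows // block_size, block_size, cols)
    scales = np.abs(blocks).max(axis=1) / qmax    # (n_blocks, cols)
    qw = np.clip(np.round(blocks / scales[:, None, :]), -qmax - 1, qmax)
    return qw.reshape(rows, cols).astype(np.int8), scales

w = np.random.randn(256, 64).astype(np.float32)
qw, scales = rtn_sketch(w, block_size=128, num_bits=4)
print(qw.shape, scales.shape)  # (256, 64) (2, 64)
```

With the per-channel block size from `update_block_size`, the same routine yields a single scale row per column, which is exactly the shape `reshape_scales_for_per_channel_nodes` then flattens.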
@@ -258,14 +262,15 @@ def quantize_rtn(
     gemm_weights_quantized = {}
     for name, w in gemm_weights.items():
         logger.debug(f"Quantizing weight {name}")
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, name)
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         qw = rtn(np.asarray(w), scales[name], block_size_updated, num_bits=num_bits)
         if has_cupy:
             qw = np.asnumpy(qw)
             scales[name] = np.asnumpy(scales[name])
         gemm_weights_quantized[name] = numpy.asarray(qw)
-    scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+    scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
     qdq.insert_dq_nodes(
         graph,
         scales,
@@ -285,7 +290,7 @@ def quantize_rtn(
         if has_cupy:
             for name in scales:
                 scales[name] = np.asnumpy(scales[name])
-        scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+        scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
         qdq.insert_qdq_nodes(graph, scales, weight_map=gemm_tensors, precision_info=precision_info)
         if gather_w_map is not None:
             assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
@@ -497,6 +502,7 @@ def _quantize_awq_clip(
             w = w.T
         w = np.asarray(w)
         num_bits = get_num_bits(precision_info, weight_tensor.name)
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
        block_size_updated = update_block_size(num_bits, block_size, w=w)
         awq_clip = AWQClipHelper(w, block_size_updated, **kwargs)
         _clip_search(x, w, awq_clip, num_bits=num_bits, **kwargs)
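`_clip_search` implements AWQ's weight clipping: it searches for a clipping ratio that minimizes the error quantization introduces. A toy sketch of the idea — the real search weighs the error by the calibration activations `x` and works per block, whereas this sketch simplifies to whole-tensor weight-space MSE, and the grid granularity is an assumption:

```python
import numpy as np

def clip_search_sketch(w: np.ndarray, num_bits: int = 4, n_grid: int = 20) -> float:
    """Find a clip ratio r so quantizing clip(w, -r*m, r*m) minimizes MSE."""
    qmax = 2 ** (num_bits - 1) - 1
    m = np.abs(w).max()
    best_ratio, best_err = 1.0, np.inf
    for ratio in np.linspace(1.0, 0.5, n_grid):    # progressively shrink the range
        clip_val = m * ratio
        wc = np.clip(w, -clip_val, clip_val)
        scale = clip_val / qmax
        qw = np.clip(np.round(wc / scale), -qmax - 1, qmax) * scale
        err = ((w - qw) ** 2).mean()
        if err < best_err:
            best_ratio, best_err = ratio, err
    return best_ratio
```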
@@ -524,7 +530,9 @@ def _quantize_awq_clip(
 
         alpha = alphas.get(weight_tensor.name, 1)
         num_bits = get_num_bits(precision_info, weight_tensor.name)
-        qw, scale, _ = quant_tensor(w, block_size, alpha=alpha, num_bits=num_bits)
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
+        block_size_updated = update_block_size(num_bits, block_size, w=w)
+        qw, scale, _ = quant_tensor(w, block_size_updated, alpha=alpha, num_bits=num_bits)
         if has_cupy:
             qw = np.asnumpy(qw)
             scale = np.asnumpy(scale)
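This hunk is the substantive fix in `_quantize_awq_clip`: previously the raw `block_size` reached `quant_tensor` even for 8-bit weights, producing block-wise scales where per-channel scales were expected. A small numpy illustration of the shape mismatch (the sizes are made up for illustration):

```python
import numpy as np

w = np.random.randn(512, 64).astype(np.float32)

def blockwise_scales(w: np.ndarray, block_size: int, qmax: int = 127) -> np.ndarray:
    blocks = w.reshape(-1, block_size, w.shape[1])
    return np.abs(blocks).max(axis=1) / qmax

print(blockwise_scales(w, 128).shape)         # (4, 64): stale 4-bit block size
print(blockwise_scales(w, w.shape[0]).shape)  # (1, 64): per-channel, as 8-bit expects
```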
@@ -561,7 +569,7 @@ def _quantize_awq_clip(
 
     t = time.time()
     dq_node_attributes = {"axis": 0, "block_size": block_size}
-    scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+    scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
     qdq.insert_dq_nodes(
         graph_gs,
         scales,
@@ -716,6 +724,7 @@ def run_awq_scale_search_per_node(
         x = np.concatenate(output_dicts[act_tensor.name], axis=0).reshape(
             (-1, w.shape[0])
         )  # n_token, ci
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, weight_tensor.name)
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         awq_lite[i] = AWQLiteHelper(x, w, block_size_updated, **kwargs)
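`AWQLiteHelper` drives AWQ's activation-aware scale search: input channels that see large activations get their weights scaled up before quantization, with the inverse scale folded into the preceding op. A compact sketch of the standard AWQ-lite formula; the alpha grid and the exact normalization are assumptions of this sketch rather than the modelopt code:

```python
import numpy as np

def awq_lite_scales_sketch(x: np.ndarray, w: np.ndarray, alpha: float) -> np.ndarray:
    """Per-input-channel AWQ scales: s_c = a_c**alpha / w_c**(1 - alpha).

    x: (n_token, ci) calibration activations; w: (ci, co) weight.
    """
    act_mag = np.abs(x).mean(axis=0)                       # (ci,)
    w_mag = np.abs(w).mean(axis=1)                         # (ci,)
    s = act_mag**alpha / np.maximum(w_mag, 1e-8) ** (1 - alpha)
    return s / np.sqrt(s.max() * s.min())                  # normalize, as in the AWQ paper

# The search evaluates a small grid of alpha values (e.g. 0.0, 0.1, ..., 1.0)
# and keeps the one minimizing the quantized layer's output error.
```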
@@ -1129,6 +1138,7 @@ def _quantize_awq_lite(
             assert enable_weight_clipping or (alpha == 1), (
                 "clip range enabled without enabling weight-clipping param"
             )
+            # Update the block size: for 8-bit quantization, per-channel quantization is used.
             num_bits = get_num_bits(precision_info, weight_tensor.name)
             block_size_updated = update_block_size(num_bits, block_size, w=w_scaled)
             qw, scale, zp = quant_tensor(
@@ -1262,7 +1272,7 @@ def _quantize_awq_lite(
 
     t = time.time()
     dq_node_attributes = {"axis": 0, "block_size": block_size}
-    scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+    scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
    qdq.insert_dq_nodes(
         graph_gs,
         scales,
@@ -1371,7 +1381,7 @@ def quantize(
             Default: 32.
         - **enable_mixed_quant** (bool): If True, enable mixed quantization.
             Default: False.
-        - **int8_layers** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
+        - **layers_8bit** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
             Default: [].
 
     **Returns**: A quantized ONNX model in ONNX ModelProto format.
     """