@@ -36,7 +36,9 @@
 from modelopt.onnx.op_types import is_fusible_scaling_op
 from modelopt.onnx.quantization.calib_utils import RandomDataProvider
 from modelopt.onnx.quantization.graph_utils import (
-    _find_quantizable_weights,
+    _find_int4_quantizable_weights as _find_quantizable_weights,
+)
+from modelopt.onnx.quantization.graph_utils import (
     expand_node_names_from_patterns,
     get_precision_info,
     get_tensor_consumer_nodes,
@@ -50,9 +52,9 @@
     find_scales,
     get_num_bits,
     quant_tensor,
+    reshape_scales_for_per_channel_nodes,
     rtn,
     update_block_size,
-    update_scale_map_for_per_channel_nodes,
 )
 from modelopt.onnx.utils import save_onnx
 
@@ -121,6 +123,7 @@ def _quantize_gather_nodes(
                 continue
             name = in_tensor.name
             w = in_tensor.values
+            # Update the block size: for 8-bit quantization, per-channel quantization is used.
             num_bits = get_num_bits(precision_info, name)
             block_size_updated = update_block_size(
                 num_bits, block_size, w=w, quantize_axis=gather_quantize_axis
@@ -170,7 +173,7 @@ def _quantize_gather_nodes(
         )
     else:
         logger.info("Found 0 Gather nodes to quantize")
-    scales_map = update_scale_map_for_per_channel_nodes(scales_map, block_size, precision_info)
+    scales_map = reshape_scales_for_per_channel_nodes(scales_map, block_size, precision_info)
     return weights_map, scales_map, zero_point_map
 
 
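The update_block_size calls that this change threads through each code path implement the per-channel fallback described in the repeated comment. A minimal sketch of the intended behavior, assuming (hypothetically) that for 8-bit weights the block collapses to one block spanning the quantize axis while 4-bit keeps the configured block size; the real helper lives alongside find_scales/quant_tensor and may differ in detail:

import numpy as np

def update_block_size_sketch(num_bits, block_size, w, quantize_axis=0):
    # Hypothetical: 8-bit weights use per-channel quantization, so one
    # block covers the whole quantize axis, overriding the configured
    # 4-bit block size.
    if num_bits == 8:
        return w.shape[quantize_axis]
    return block_size

w = np.zeros((4096, 11008), dtype=np.float32)
assert update_block_size_sketch(4, 128, w) == 128   # 4-bit: block-wise
assert update_block_size_sketch(8, 128, w) == 4096  # 8-bit: per-channel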
@@ -221,6 +224,7 @@ def quantize_rtn(
     precision_info = get_precision_info(onnx_model, nodes_to_exclude, **kwargs)
     for name, w in gemm_weights.items():
         logger.debug(f"Computing scales for weight {name} of shape {w.shape}")
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, name)
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         s, zp = find_scales(np.asarray(w), block_size_updated, num_bits=num_bits)
@@ -258,14 +262,15 @@ def quantize_rtn(
     gemm_weights_quantized = {}
     for name, w in gemm_weights.items():
         logger.debug(f"Quantizing weight {name}")
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, name)
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         qw = rtn(np.asarray(w), scales[name], block_size_updated, num_bits=num_bits)
         if has_cupy:
             qw = np.asnumpy(qw)
             scales[name] = np.asnumpy(scales[name])
         gemm_weights_quantized[name] = numpy.asarray(qw)
-    scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+    scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
     qdq.insert_dq_nodes(
         graph,
         scales,
@@ -285,7 +290,7 @@ def quantize_rtn(
         if has_cupy:
             for name in scales:
                 scales[name] = np.asnumpy(scales[name])
-        scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+        scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
         qdq.insert_qdq_nodes(graph, scales, weight_map=gemm_tensors, precision_info=precision_info)
         if gather_w_map is not None:
             assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
@@ -497,6 +502,7 @@ def _quantize_awq_clip(
             w = w.T
         w = np.asarray(w)
         num_bits = get_num_bits(precision_info, weight_tensor.name)
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         awq_clip = AWQClipHelper(w, block_size_updated, **kwargs)
         _clip_search(x, w, awq_clip, num_bits=num_bits, **kwargs)
@@ -524,7 +530,9 @@ def _quantize_awq_clip(
 
         alpha = alphas.get(weight_tensor.name, 1)
         num_bits = get_num_bits(precision_info, weight_tensor.name)
-        qw, scale, _ = quant_tensor(w, block_size, alpha=alpha, num_bits=num_bits)
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
+        block_size_updated = update_block_size(num_bits, block_size, w=w)
+        qw, scale, _ = quant_tensor(w, block_size_updated, alpha=alpha, num_bits=num_bits)
         if has_cupy:
             qw = np.asnumpy(qw)
             scale = np.asnumpy(scale)
@@ -561,7 +569,7 @@ def _quantize_awq_clip(
 
     t = time.time()
     dq_node_attributes = {"axis": 0, "block_size": block_size}
-    scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+    scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
     qdq.insert_dq_nodes(
         graph_gs,
         scales,
@@ -716,6 +724,7 @@ def run_awq_scale_search_per_node(
             x = np.concatenate(output_dicts[act_tensor.name], axis=0).reshape(
                 (-1, w.shape[0])
             )  # n_token, ci
+            # Update the block size: for 8-bit quantization, per-channel quantization is used.
             num_bits = get_num_bits(precision_info, weight_tensor.name)
             block_size_updated = update_block_size(num_bits, block_size, w=w)
             awq_lite[i] = AWQLiteHelper(x, w, block_size_updated, **kwargs)
@@ -1129,6 +1138,7 @@ def _quantize_awq_lite(
         assert enable_weight_clipping or (alpha == 1), (
             "clip range enabled without enabling weight-clipping param"
         )
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, weight_tensor.name)
         block_size_updated = update_block_size(num_bits, block_size, w=w_scaled)
         qw, scale, zp = quant_tensor(
@@ -1262,7 +1272,7 @@ def _quantize_awq_lite(
 
     t = time.time()
     dq_node_attributes = {"axis": 0, "block_size": block_size}
-    scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+    scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
     qdq.insert_dq_nodes(
         graph_gs,
         scales,
@@ -1371,7 +1381,7 @@ def quantize(
             Default: 32.
         - **enable_mixed_quant** (bool): If True, enable mixed quantization.
             Default: False.
-        - **int8_layers** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
+        - **layers_8bit** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
             Default: [].
     **Returns**: A quantized ONNX model in ONNX ModelProto format.
     """
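A minimal usage sketch for the renamed layers_8bit kwarg; the import path, the positional model argument, and the pattern strings are assumptions for illustration, not the documented API:

import onnx
from modelopt.onnx.quantization.int4 import quantize  # assumed module path

model = onnx.load("model.onnx")
# Mixed INT4/INT8 quantization: layers matching the comma-separated
# patterns in layers_8bit stay at 8 bits; everything else goes to 4 bits.
quantized = quantize(
    model,
    block_size=32,                   # documented default
    enable_mixed_quant=True,
    layers_8bit="lm_head,layers.0",  # hypothetical layer patterns
)
onnx.save(quantized, "model.quant.onnx")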