From 024f97aeb5e5ffccc0e1b38792601c188c1daafe Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 10 Sep 2025 00:14:07 +0530 Subject: [PATCH 1/2] [5506930]Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models Signed-off-by: unknown --- .../windows/onnx_ptq/genai_llm/quantize.py | 20 +- modelopt/onnx/quantization/int4.py | 377 +++++++++++------- modelopt/onnx/quantization/qdq_utils.py | 72 +++- modelopt/onnx/quantization/quant_utils.py | 184 +++++++++ 4 files changed, 493 insertions(+), 160 deletions(-) diff --git a/examples/windows/onnx_ptq/genai_llm/quantize.py b/examples/windows/onnx_ptq/genai_llm/quantize.py index 3ca97dba..9d469956 100644 --- a/examples/windows/onnx_ptq/genai_llm/quantize.py +++ b/examples/windows/onnx_ptq/genai_llm/quantize.py @@ -365,7 +365,7 @@ def main(args): f"\n--Quantize-Script-- algo={args.algo}, dataset={args.dataset}, calib_size={args.calib_size}, " f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, " f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, " - f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32}\n" + f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} k_quant_mixed={args.k_quant_mixed}\n" ) print( @@ -435,6 +435,8 @@ def main(args): awqclip_alpha_step=args.awqclip_alpha_step, awqclip_alpha_min=args.awqclip_alpha_min, awqclip_bsz_col=args.awqclip_bsz_col, + k_quant_mixed=args.k_quant_mixed, + int8_layers=args.int8_layers, ) logging.info(f"\nQuantization process took {time.time() - t} seconds") @@ -594,6 +596,20 @@ def main(args): default=False, action="store_true", ) - + parser.add_argument( + "--k_quant_mixed", + default=False, + action="store_true", + help="True when we want to use k_quant_mixed quantization", + ) + parser.add_argument( + "--int8_layers", + type=str, + default="", + help=( + "Comma-separated list of layer patterns to quantize to INT8 instead of INT4." 
+ "Example: 'layers.0,layers.1,lm_head'" + ), + ) args = parser.parse_args() main(args) diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py index 33a13d31..16737120 100644 --- a/modelopt/onnx/quantization/int4.py +++ b/modelopt/onnx/quantization/int4.py @@ -19,6 +19,7 @@ import gc import math import os +import re import tempfile import time from collections.abc import Sequence @@ -42,6 +43,7 @@ ) from modelopt.onnx.quantization.gs_patching import patch_gs_modules from modelopt.onnx.quantization.ort_utils import create_inference_session +from modelopt.onnx.quantization.quant_utils import _pad, dq_tensor, find_scales, quant_tensor, rtn from modelopt.onnx.utils import save_onnx __all__ = ["quantize"] @@ -85,125 +87,6 @@ CLIP_MIN = 1e-5 -def _next_block_size_multiple(x: float, block_size: int) -> float: - return math.ceil(x / block_size) * block_size - - -def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray: - """Pads `w` to next largest multiple of block_size, on quantize_axis.""" - assert quantize_axis <= len(w.shape), ( - f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}" - ) - - if w.shape[quantize_axis] % block_size == 0: - return w - - pad_width = ( - _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis] - ) - pads = [(0, 0) for _ in range(len(w.shape))] - pads[quantize_axis] = (0, pad_width) - return np.pad(w, pads, mode="constant", constant_values=0) - - -def _depad(w: np.ndarray, orig_shape: tuple, quantize_axis: int = 0) -> np.ndarray: - """Depad quantize_axis to original shape.""" - if w.shape == orig_shape: - return w - ans = None - if quantize_axis == 0: - ans = w[0 : orig_shape[0], ...] - elif quantize_axis == 1: - ans = w[..., 0 : orig_shape[1]] - else: - raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array") - return ans - - -def find_scales( - w: np.ndarray, - block_size: int, - quantize_axis: int = 0, - alpha: float = 1.0, - use_zero_point: bool = False, -): - """Find scale factors for `w` via `s = max(w.block(block_size)) / 7`.""" - w = _pad(w, block_size, quantize_axis) - if quantize_axis == 0: - w = w.T - s_last_dim = w.shape[-1] // block_size - s_shape = list(w.shape) - s_shape[-1] = s_last_dim - z = None - if not use_zero_point: - w_amax = np.abs(w.reshape(-1, block_size)).max(axis=-1) - s = (w_amax * alpha) / INT4_SCALE - s = s.reshape(s_shape) - else: - max_val = w.reshape(-1, block_size).max(axis=-1) - min_val = w.reshape(-1, block_size).min(axis=-1) - max_int = UINT4_MAX - min_int = UINT4_MIN - s = (max_val - min_val).clip(min=CLIP_MIN) / max_int - # z = -np.round(temp).clip(min=min_int, max=max_int) # gives 0 - need to check - temp = min_val / s - temp = np.round(temp) - temp = -temp - temp = temp.clip(min=min_int, max=max_int) - z = temp - assert s.shape == z.shape, "s and z shape mismatch" - s = s.reshape(s_shape) - z = z.reshape(s_shape) - assert z is None or use_zero_point is True, "zero-point value and use-zero-point not in sync" - if quantize_axis == 0: - s = s.T - if z is not None: - z = z.T - return s, z - - -def rtn( - w: np.ndarray, s: np.ndarray, block_size: int, quantize_axis: int = 0, zp: np.ndarray = None -) -> np.ndarray: - """Quantizes `w` with scale factors `s` via Round-to-Nearest. - - Ties are broken by rounding to the nearest even number. 
- """ - w_padded = _pad(w, block_size, quantize_axis) - num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis] - if zp is None: - w_padded = ( - np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis)) - .clip(INT4_MIN, INT4_MAX) - .astype(np.int8) - ) - else: - w_padded = ( - ( - np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis)) - + zp.repeat(num_blocks, axis=quantize_axis) - ) - .clip(UINT4_MIN, UINT4_MAX) - .astype(np.int8) - ) - return _depad(w_padded, w.shape, quantize_axis) - - -def dq_tensor( - w: np.ndarray, s: np.ndarray, block_size: int, quantize_axis: int = 0, zp: np.ndarray = None -) -> np.ndarray: - """Dequantizes `w` with scale factors `s`.""" - w_padded = _pad(w, block_size, quantize_axis) - num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis] - if zp is None: - w_padded = w_padded * s.repeat(num_blocks, axis=quantize_axis) - else: - w_padded = (w_padded - zp.repeat(num_blocks, axis=quantize_axis)) * s.repeat( - num_blocks, axis=quantize_axis - ) - return _depad(w_padded, w.shape, quantize_axis) - - def _quantize_gather_nodes( graph: onnx.GraphProto, nodes_to_exclude: list[str], @@ -278,6 +161,7 @@ def quantize_rtn( block_size: int, dq_only: bool = False, nodes_to_exclude: list[str] = [], + precision_info: dict[str, int] | None = None, **kwargs: Any, ) -> onnx.ModelProto: """Quantizes `onnx_model` using the RTN (Round-to-Nearest) algorithm. @@ -319,7 +203,7 @@ def quantize_rtn( gemm_io_type = {} for name, w in gemm_weights.items(): logger.debug(f"Computing scales for weight {name} of shape {w.shape}") - s, zp = find_scales(np.asarray(w), block_size) + s, zp = find_scales(np.asarray(w), block_size, precision_info=precision_info, name=name) assert zp is None, "zero-point is not enabled but zp is found non-None" scales[name] = s gemm_io_type[name] = onnx.helper.np_dtype_to_tensor_dtype(cast("int", w.dtype)) @@ -346,48 +230,53 @@ def quantize_rtn( dq_only=dq_only, ) + is_per_channel = block_size == -1 if dq_only: # Calculate actual quantized weights. 
logger.info("Computing quantized weights for DQ-only mode") gemm_weights_quantized = {} for name, w in gemm_weights.items(): logger.debug(f"Quantizing weight {name}") - qw = rtn(np.asarray(w), scales[name], block_size) + qw = rtn( + np.asarray(w), scales[name], block_size, precision_info=precision_info, name=name + ) if has_cupy: qw = np.asnumpy(qw) scales[name] = np.asnumpy(scales[name]) gemm_weights_quantized[name] = numpy.asarray(qw) - qdq.insert_dq_nodes(graph, scales, quantized_weights=gemm_weights_quantized) + qdq.insert_dq_nodes( + graph, + scales, + quantized_weights=gemm_weights_quantized, + precision_info=precision_info, + is_per_channel=is_per_channel, + ) + if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" - qdq.insert_dq_nodes(graph, gather_s_map, quantized_weights=gather_w_map) + qdq.insert_dq_nodes( + graph, + gather_s_map, + quantized_weights=gather_w_map, + precision_info=precision_info, + is_per_channel=is_per_channel, + ) else: if has_cupy: for name in scales: scales[name] = np.asnumpy(scales[name]) - qdq.insert_qdq_nodes(graph, scales, weight_map=gemm_tensors) + qdq.insert_qdq_nodes(graph, scales, weight_map=gemm_tensors, precision_info=precision_info) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" - qdq.insert_qdq_nodes(graph, gather_s_map, weight_map=gather_w_map) + qdq.insert_qdq_nodes( + graph, gather_s_map, weight_map=gather_w_map, precision_info=precision_info + ) logger.info(f"RTN quantization completed in {time.time() - t_start:.2f} seconds") return gs.export_onnx(graph) -def quant_tensor( - w: np.ndarray, - block_size: int, - quantize_axis: int = 0, - alpha: float = 1.0, - use_zero_point: bool = False, -): - """Quantize a tensor using alpha etc. and return the quantized tensor.""" - scale, zp = find_scales(w, block_size, quantize_axis, alpha, use_zero_point) - wq = rtn(w, scale, block_size, quantize_axis, zp) - return wq, scale, zp - - class AWQClipHelper: """AWQ calibration helper class.""" @@ -424,7 +313,13 @@ def update_best_params(self): def _clip_search( - x: np.ndarray, w: np.ndarray, awq_clip: AWQClipHelper, max_tokens: int = 64, **kwargs + x: np.ndarray, + w: np.ndarray, + awq_clip: AWQClipHelper, + max_tokens: int = 64, + precision_info: dict[str, int] | None = None, + name: str | None = None, + **kwargs, ): """Apply AWQ algorithm on a weight and return optimum alpha. 
@@ -464,8 +359,10 @@ def _clip_search( # Compute loss for each alpha value for alpha in awq_clip.loss: # Perform QDQ on the whole original weight tensor - qw, scales, _ = quant_tensor(w_copy, block_size, alpha=alpha) - cur_w = dq_tensor(qw, scales, block_size) + qw, scales, _ = quant_tensor( + w_copy, block_size, alpha=alpha, precision_info=precision_info, name=name + ) + cur_w = dq_tensor(qw, scales, block_size, precision_info=precision_info, name=name) # Reshape before getting the batch of size co_bsz to multiply with input cur_w = cur_w.T # ci, co -> co, ci @@ -566,6 +463,7 @@ def _quantize_awq_clip( force_fp16: bool = False, nodes_to_exclude: list[str] = [], input_shapes_profile: Sequence[dict[str, str]] | None = None, + precision_info: dict[str, int] | None = None, **kwargs: Any, ) -> onnx.ModelProto: """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" @@ -625,7 +523,9 @@ def _quantize_awq_clip( w = np.asarray(w) awq_clip = AWQClipHelper(w, block_size, **kwargs) - _clip_search(x, w, awq_clip, **kwargs) + _clip_search( + x, w, awq_clip, precision_info=precision_info, name=weight_tensor.name, **kwargs + ) alphas[weight_tensor.name] = awq_clip.best_alpha logger.info(f"Clip search for all weights took {time.time() - t} seconds") @@ -649,7 +549,9 @@ def _quantize_awq_clip( w = np.asarray(w) alpha = alphas.get(weight_tensor.name, 1) - qw, scale, _ = quant_tensor(w, block_size, alpha=alpha) + qw, scale, _ = quant_tensor( + w, block_size, alpha=alpha, precision_info=precision_info, name=weight_tensor.name + ) if has_cupy: qw = np.asnumpy(qw) scale = np.asnumpy(scale) @@ -685,14 +587,25 @@ def _quantize_awq_clip( t = time.time() dq_node_attributes = {"axis": 0, "block_size": block_size} + is_per_channel = block_size == -1 qdq.insert_dq_nodes( - graph_gs, scales, quantized_weights=gemm_weights_quantized, attributes=dq_node_attributes + graph_gs, + scales, + quantized_weights=gemm_weights_quantized, + attributes=dq_node_attributes, + precision_info=precision_info, + is_per_channel=is_per_channel, ) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size} qdq.insert_dq_nodes( - graph_gs, scales, quantized_weights=gather_w_map, attributes=gather_dq_node_attributes + graph_gs, + scales, + quantized_weights=gather_w_map, + attributes=gather_dq_node_attributes, + precision_info=precision_info, + is_per_channel=is_per_channel, ) logger.info(f"Inserting DQ nodes took {time.time() - t} seconds") @@ -789,6 +702,7 @@ def run_awq_scale_search_per_node( enable_fast_path_using_high_sysram, output_data, clip_alphas, + precision_info: dict[str, int] | None = None, **kwargs: Any, ): """Method that iterates over each quantizable node for scale search.""" @@ -845,8 +759,16 @@ def run_awq_scale_search_per_node( x_scaled = x * 1.0 / awq_scale w_scaled = w * awq_scale[:, np.newaxis] - qw, scale, zp = quant_tensor(w_scaled, block_size, use_zero_point=use_zero_point) - dqw = dq_tensor(qw, scale, block_size, zp=zp) + qw, scale, zp = quant_tensor( + w_scaled, + block_size, + use_zero_point=use_zero_point, + precision_info=precision_info, + name=weight_tensor.name, + ) + dqw = dq_tensor( + qw, scale, block_size, zp=zp, precision_info=precision_info, name=weight_tensor.name + ) out_curr = x_scaled.__matmul__(dqw) loss = np.mean(np.power((out_actual - out_curr), 2)) del out_curr @@ -856,7 +778,9 @@ def run_awq_scale_search_per_node( if 
enable_weight_clipping: w = w * (awq_lite[i].best_scale[:, np.newaxis]) awq_clip = AWQClipHelper(w, block_size, **kwargs) - _clip_search(x, w, awq_clip, **kwargs) + _clip_search( + x, w, awq_clip, precision_info=precision_info, name=weight_tensor.name, **kwargs + ) clip_alphas[weight_tensor.name] = awq_clip.best_alpha del x, w, out_actual, output_dicts if has_cupy: @@ -955,6 +879,7 @@ def run_awq_scale_search_per_subgraph( awq_lite, inputs, tqdm_msg_append_str, + precision_info: dict[str, int] | None = None, **kwargs: Any, ): """Method that iterates over each quantizable subgraph/siblings for scale search.""" @@ -1010,8 +935,21 @@ def run_awq_scale_search_per_subgraph( assert out_act is not None x_scaled = x * 1.0 / awq_scale w_scaled = w * awq_scale[:, np.newaxis] - qw, scale, zp = quant_tensor(w_scaled, block_size, use_zero_point=use_zero_point) - dqw = dq_tensor(qw, scale, block_size, zp=zp) + qw, scale, zp = quant_tensor( + w_scaled, + block_size, + use_zero_point=use_zero_point, + precision_info=precision_info, + name=weight_tensor.name, + ) + dqw = dq_tensor( + qw, + scale, + block_size, + zp=zp, + precision_info=precision_info, + name=weight_tensor.name, + ) out_curr = x_scaled.__matmul__(dqw) loss += np.mean(np.power((out_act - out_curr), 2)) del out_curr, out_act @@ -1071,6 +1009,7 @@ def _quantize_awq_lite( use_zero_point: bool = False, nodes_to_exclude: list[str] = [], input_shapes_profile: Sequence[dict[str, str]] | None = None, + precision_info: dict[str, int] | None = None, **kwargs: Any, ) -> onnx.ModelProto: """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" @@ -1170,6 +1109,7 @@ def _quantize_awq_lite( awq_lite, inputs, msg, + precision_info, **kwargs, ) else: @@ -1186,6 +1126,7 @@ def _quantize_awq_lite( enable_fast_path_using_high_sysram, output_data, clip_alphas, + precision_info, **kwargs, ) assert len(awq_lite) == len(wa_pack) @@ -1238,8 +1179,14 @@ def _quantize_awq_lite( "clip range enabled without enabling weight-clipping param" ) qw, scale, zp = quant_tensor( - w_scaled, block_size, alpha=alpha, use_zero_point=use_zero_point + w_scaled, + block_size, + alpha=alpha, + use_zero_point=use_zero_point, + precision_info=precision_info, + name=weight_tensor.name, ) + assert use_zero_point is True or zp is None, "zp is not according to use-zero-point setting" if do_transpose: qw = qw.T @@ -1362,12 +1309,14 @@ def _quantize_awq_lite( t = time.time() dq_node_attributes = {"axis": 0, "block_size": block_size} + qdq.insert_dq_nodes( graph_gs, scales, quantized_weights=gemm_weights_quantized, attributes=dq_node_attributes, zero_points=zero_points if use_zero_point else None, + precision_info=precision_info, ) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" @@ -1381,6 +1330,7 @@ def _quantize_awq_lite( quantized_weights=gather_w_map, attributes=gather_dq_node_attributes, zero_points=gather_zp_map if use_zero_point else None, + precision_info=precision_info, ) if pre_quant_scale: qdq.insert_pre_quant_scale_nodes(graph_gs, input_tensors, pre_quant_scale) @@ -1404,6 +1354,119 @@ def _quantize_awq_lite( return model +def should_quantize_to_int8(layer_name: str, int8_layers: list[str]): + """Check if layer should be quantized to INT8. + + The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'. + The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'. 
+ + To match these, we: + - Remove the leading slash from the node name. + - Replace all '/' with '.' to match the naming convention of the initializer. + + This allows us to correctly identify which weights should be quantized to INT8. + """ + if not int8_layers: + return False + normalized_patterns = [] + for pattern in int8_layers: + p = pattern.lstrip("/") + p = p.replace("/", ".") + normalized_patterns.append(p) + return any(norm_pattern in layer_name for norm_pattern in normalized_patterns) + + +def get_layer_precision_mapping( + onnx_model: onnx.ModelProto, + int8_precision_pattern: str | None = None, + nodes_to_exclude: list[str] | None = [r"/lm_head"], +): + graph = onnx_model.graph + + nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude) + # Collect quantizable weight tensors + wa_pack = _find_quantizable_weights(graph, nodes_to_exclude) + + if int8_precision_pattern: + int8_layers_list = [x.strip() for x in int8_precision_pattern.split(",") if x.strip()] + + else: + matmul_nodes = [ + node + for node in onnx_model.graph.node + if node.op_type == "MatMul" and "lm_head" not in node.name + ] + + # Only include nodes matching the specified patterns for all layers present in the model + # For example, for all i where a node exists with name: + # /model/layers.{i}/attn/qkv_proj/MatMul + # /model/layers.{i}/attn/v_proj/MatMul + # /model/layers.{i}/mlp/down_proj/MatMul + pattern_regexes = [ + re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"), + re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"), + re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"), + ] + + # Filter matmul_nodes to only those matching the patterns + filtered_matmul_nodes = [] + for node in matmul_nodes: + for pat in pattern_regexes: + if pat.match(node.name): + filtered_matmul_nodes.append(node) + break + + # Build a mapping from group key to list of node names (ordered by layer index if possible) + def extract_group_key(node_name): + # Extract the two components before 'MatMul' in the name, e.g. 
...foo.bar.MatMul + parts = node_name.split("/") + if len(parts) >= 3: + return ".".join(parts[-3:-1]) + return node_name + + group_to_nodes = {} + for node in filtered_matmul_nodes: + group_key = extract_group_key(node.name) + group_to_nodes.setdefault(group_key, []).append(node.name) + + int8_layers_set = set() + for names in group_to_nodes.values(): + n = len(names) + if n == 0: + continue + + # Try to sort by layer index if present + def layer_idx(name): + m = re.search(r"layers\.(\d+)\.", name) + return int(m.group(1)) if m else 0 + + names_sorted = sorted(names, key=layer_idx) + first_eighth = int(n // 8) + last_eighth = int(n // 8) + # First 1/8 + int8_layers_set.update(names_sorted[:first_eighth]) + # Last 1/8 + if last_eighth > 0: + int8_layers_set.update(names_sorted[-last_eighth:]) + # Every third in the rest (excluding first and last eighth) + rest_start = first_eighth + rest_end = n - last_eighth + for i in range(rest_start, rest_end): + if (i - rest_start) % 3 == 0: + int8_layers_set.add(names_sorted[i]) + int8_layers_list = list(int8_layers_set) + + # NEW: Create precision info mapping + precision_info = {} + for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack): + weight_name = weight_tensor.name + if should_quantize_to_int8(weight_name, int8_layers_list): + precision_info[weight_name] = 8 + else: + precision_info[weight_name] = 4 + return precision_info + + def quantize( onnx_path: str | onnx.ModelProto, calibration_method: str = "awq_lite", @@ -1415,6 +1478,8 @@ def quantize( nodes_to_exclude: list[str] | None = [r"/lm_head"], log_level: str = "INFO", input_shapes_profile: Sequence[dict[str, str]] | None = None, + k_quant_mixed: bool = False, + int8_layers: str | None = None, **kwargs: Any, ) -> onnx.ModelProto: """Applies INT4 Weight-Only-Quantization (WoQ) to an ONNX model. 
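To make the default selection above concrete, here is a small sketch of which decoder layers end up INT8 for one projection group when no int8_layers override is given (a hypothetical 32-layer model; the indices follow the first-1/8, last-1/8, every-third rule implemented above):

    n = 32                      # number of matching MatMuls in one group, e.g. attn/qkv_proj
    first_eighth = n // 8       # 4 -> layers 0..3 kept at INT8
    last_eighth = n // 8        # 4 -> layers 28..31 kept at INT8
    int8_idxs = set(range(first_eighth)) | set(range(n - last_eighth, n))
    int8_idxs |= {i for i in range(first_eighth, n - last_eighth) if (i - first_eighth) % 3 == 0}
    print(sorted(int8_idxs))
    # [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 29, 30, 31]
    # i.e. 16 of the 32 qkv_proj MatMuls stay INT8; the remaining ones, and any projection
    # not matched by the patterns, are quantized to INT4.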
@@ -1480,6 +1545,9 @@ def quantize( block_size = 128 logger.info(f"Using default block size: {block_size}") + if block_size == -1: + logger.info("Using per-channel quantization") + # set config params nodes_to_exclude = nodes_to_exclude or [] logger.debug(f"Excluding nodes matching patterns: {nodes_to_exclude}") @@ -1493,6 +1561,10 @@ def quantize( else: onnx_model = onnx_path + precision_info = None + if k_quant_mixed: + precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude) + # Initialize calibration_data_reader if not provided if calibration_data_reader is None: calibration_data_reader = RandomDataProvider(onnx_model) @@ -1506,6 +1578,7 @@ def quantize( block_size, dq_only="dq" in calibration_method, nodes_to_exclude=nodes_to_exclude, + precision_info=precision_info, **kwargs, ) elif calibration_method in ["awq_lite", "awq_full"]: @@ -1525,6 +1598,7 @@ def quantize( use_zero_point=use_zero_point, enable_weight_clipping=do_weight_clipping, input_shapes_profile=input_shapes_profile, + precision_info=precision_info, **kwargs, ) elif calibration_method in ["awq_clip", "awq_clip_trt"]: @@ -1536,6 +1610,7 @@ def quantize( block_size, nodes_to_exclude=nodes_to_exclude, input_shapes_profile=input_shapes_profile, + precision_info=precision_info, **kwargs, ) else: diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index 38ed010c..0a3c54e2 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -54,7 +54,11 @@ "Half": onnx.TensorProto.FLOAT16, "INT8": onnx.TensorProto.INT8, "UINT8": onnx.TensorProto.UINT8, + "INT4": onnx.TensorProto.INT4, + "UINT4": onnx.TensorProto.UINT4, } +onnx_bit_dtype_signed_map = {4: "INT4", 8: "INT8"} +onnx_bit_dtype_unsigned_map = {4: "UINT4", 8: "UINT8"} np_dtype_map = { "Float": np.float32, @@ -303,12 +307,28 @@ def _insert_helper( graph.toposort() +def update_attributes(attrib: dict[str, Any] | None = None, is_per_channel: bool = False): + """Update attribute dictionary for quantization nodes. + + If per-channel quantization is enabled, sets the 'axis' attribute to 1 and removes + the 'block_size' attribute if present. + """ + if is_per_channel: + if attrib is not None: + attrib["axis"] = 1 + if "block_size" in attrib: + attrib.pop("block_size") + return attrib + + def insert_dq_nodes( graph: gs.Graph, scales: dict[str, np.ndarray], quantized_weights: dict[str, np.ndarray], attributes: dict[str, Any] | None = None, zero_points: dict[str, np.ndarray] | None = None, + precision_info: dict[str, int] | None = None, + is_per_channel: bool = False, ): """Insert new initializers and DQ nodes into graph. 
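A short illustration of the dtype maps and the update_attributes helper added above (a sketch; the values in the comments are what the code as written returns):

    from modelopt.onnx.quantization.qdq_utils import (
        onnx_bit_dtype_signed_map,
        onnx_bit_dtype_unsigned_map,
        onnx_dtype_map,
        update_attributes,
    )

    # Per-channel INT8 weights drop the block_size attribute and quantize along axis 1.
    update_attributes({"axis": 0, "block_size": 128}, is_per_channel=True)  # -> {"axis": 1}
    update_attributes({"axis": 0, "block_size": 128})  # unchanged for block-wise INT4

    # The DQ initializer dtype is picked from the per-weight bit width and zero-point usage.
    onnx_dtype_map[onnx_bit_dtype_signed_map[4]]    # -> onnx.TensorProto.INT4
    onnx_dtype_map[onnx_bit_dtype_unsigned_map[8]]  # -> onnx.TensorProto.UINT8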
@@ -325,10 +345,28 @@ def _insert_helper( wq: np.ndarray, scale: np.ndarray, dq_nodes: dict[str, gs.Node], - attrs: dict[str, Any], zp: np.ndarray, + attrs: dict[str, Any] | None = None, + precision_info: dict[str, int] | None = None, + is_per_channel: bool = False, ): - tensor_dtype = onnx.TensorProto.INT4 if zp is None else onnx.TensorProto.UINT4 + attrib = dict(attrs) if attrs is not None else None + if precision_info and name in precision_info: + tensor_dtype = ( + onnx_dtype_map[onnx_bit_dtype_signed_map[precision_info[name]]] + if zp is None + else onnx_dtype_map[onnx_bit_dtype_unsigned_map[precision_info[name]]] + ) + # do per-channel quantization for int8 as no support for int8 block-wise dq node + if precision_info[name] == 8: + # reshape scale to be per-channel + scale = scale.reshape(-1) + attrib = update_attributes(attrib, True) + else: + tensor_dtype = onnx.TensorProto.INT4 if zp is None else onnx.TensorProto.UINT4 + + attrib = update_attributes(attrib, is_per_channel) + wq_tensor = make_gs_quantized_weight(name, wq, tensor_dtype) scale_tensor = make_gs_scale(name, scale) dq_out = make_gs_dequantize_output(name, shape=wq.shape, dtype=scale.dtype) @@ -340,7 +378,7 @@ def _insert_helper( name, inputs=inputs, outputs=[dq_out], - attributes=attrs, + attributes=attrib, ) dq_nodes[name] = dq_node @@ -350,7 +388,16 @@ def _insert_helper( if zero_points is not None: zp = zero_points.get(name) assert zp is not None, "zero-point is enabled but zero-point values not found" - _insert_helper(name, quantized_weights[name], scale, dq_nodes, attributes, zp) # type: ignore[arg-type] + _insert_helper( + name, + quantized_weights[name], + scale, + dq_nodes, + zp, + attributes, + precision_info, + is_per_channel, + ) _postprocess_qdq( graph, @@ -363,6 +410,7 @@ def insert_qdq_nodes( graph: gs.Graph, scales: dict[str, np.ndarray], weight_map: dict[str, gs.Tensor], + precision_info: dict[str, int] | None = None, ): """Insert scales and QDQ nodes into graph. 
@@ -379,10 +427,20 @@ def _insert_helper( scale: np.ndarray, q_nodes: dict[str, gs.Node], dq_nodes: dict[str, gs.Node], + precision_info: dict[str, int] | None = None, ): + if precision_info and name in precision_info: + tensor_dtype = onnx_dtype_map[onnx_bit_dtype_signed_map[precision_info[name]]] + # do per-channel quantization for int8 as no support for int8 block-wise dq node + if precision_info[name] == 8: + # reshape scale to be per-channel + scale = scale.reshape(-1) + else: + tensor_dtype = onnx.TensorProto.INT4 + scale_tensor = make_gs_scale(name, scale) - zp_tensor = make_gs_zp(name, scale.shape, onnx.TensorProto.INT4) - q_out = make_gs_quantize_output(name, weight_to_quantize.shape, onnx.TensorProto.INT4) + zp_tensor = make_gs_zp(name, scale.shape, tensor_dtype) + q_out = make_gs_quantize_output(name, weight_to_quantize.shape, tensor_dtype) q_node = make_gs_quantize_node( name, inputs=[weight_to_quantize, scale_tensor, zp_tensor], outputs=[q_out] ) @@ -395,7 +453,7 @@ def _insert_helper( q_nodes, dq_nodes = {}, {} for name, scale in scales.items(): - _insert_helper(name, weight_map[name], scale, q_nodes, dq_nodes) + _insert_helper(name, weight_map[name], scale, q_nodes, dq_nodes, precision_info) _postprocess_qdq( graph, diff --git a/modelopt/onnx/quantization/quant_utils.py b/modelopt/onnx/quantization/quant_utils.py index f108594b..cab7310b 100644 --- a/modelopt/onnx/quantization/quant_utils.py +++ b/modelopt/onnx/quantization/quant_utils.py @@ -15,6 +15,7 @@ """Provides some basic utilities that can be used in quantize() methods.""" +import math from collections.abc import Sequence import numpy as np @@ -23,6 +24,9 @@ INT4_MAX = 7 UINT4_MIN = 0 UINT4_MAX = 15 +# following min-value for clip is taken from AutoAWQ where zero-point based quantization is +# supported and working +CLIP_MIN = 1e-5 def pack_float32_to_4bit_optimized(array: np.ndarray | Sequence, signed: bool) -> np.ndarray: @@ -153,6 +157,186 @@ def get_weights_scaling_factor( return q_per_block_scale.astype(np.float32) +def get_num_bits(precision_info: dict[str, int] | None = None, name: str | None = None) -> int: + """Determine the number of bits for quantization from precision_info.""" + if precision_info and name in precision_info: + num_bits = precision_info[name] + else: + num_bits = 4 + return num_bits + + +def _next_block_size_multiple(x: float, block_size: int) -> float: + return math.ceil(x / block_size) * block_size + + +def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray: + """Pads `w` to next largest multiple of block_size, on quantize_axis.""" + assert quantize_axis <= len(w.shape), ( + f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}" + ) + + if w.shape[quantize_axis] % block_size == 0: + return w + + pad_width = ( + _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis] + ) + pads = [(0, 0) for _ in range(len(w.shape))] + pads[quantize_axis] = (0, pad_width) + return np.pad(w, pads, mode="constant", constant_values=0) + + +def _depad(w: np.ndarray, orig_shape: tuple, quantize_axis: int = 0) -> np.ndarray: + """Depad quantize_axis to original shape.""" + if w.shape == orig_shape: + return w + ans = None + if quantize_axis == 0: + ans = w[0 : orig_shape[0], ...] 
+ elif quantize_axis == 1: + ans = w[..., 0 : orig_shape[1]] + else: + raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array") + return ans + + +def find_scales( + w: np.ndarray, + block_size: int, + quantize_axis: int = 0, + alpha: float = 1.0, + use_zero_point: bool = False, + precision_info: dict[str, int] | None = None, + name: str | None = None, +): + """Find scale factors for `w` via `s = max(w.block(block_size)) / 7`.""" + num_bits = get_num_bits(precision_info, name) + # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node, + # set block_size to the size of the quantize_axis dimension to do per-channel quantization + if block_size == -1 or num_bits == 8: + block_size = w.shape[quantize_axis] + w = _pad(w, block_size, quantize_axis) + if quantize_axis == 0: + w = w.T + + s_last_dim = w.shape[-1] // block_size + s_shape = list(w.shape) + s_shape[-1] = s_last_dim + z = None + if not use_zero_point: + scale = 2 ** (num_bits - 1) + w_amax = np.abs(w.reshape(-1, block_size)).max(axis=-1) + s = (w_amax * alpha) / scale + s = s.reshape(s_shape) + else: + max_val = w.reshape(-1, block_size).max(axis=-1) + min_val = w.reshape(-1, block_size).min(axis=-1) + max_int = (2**num_bits) - 1 + min_int = 0 + s = (max_val - min_val).clip(min=CLIP_MIN) / max_int + # z = -np.round(temp).clip(min=min_int, max=max_int) # gives 0 - need to check + temp = min_val / s + temp = np.round(temp) + temp = -temp + temp = temp.clip(min=min_int, max=max_int) + z = temp + assert s.shape == z.shape, "s and z shape mismatch" + s = s.reshape(s_shape) + z = z.reshape(s_shape) + assert z is None or use_zero_point is True, "zero-point value and use-zero-point not in sync" + if quantize_axis == 0: + s = s.T + if z is not None: + z = z.T + return s, z + + +def rtn( + w: np.ndarray, + s: np.ndarray, + block_size: int, + quantize_axis: int = 0, + zp: np.ndarray = None, + precision_info: dict[str, int] | None = None, + name: str | None = None, +) -> np.ndarray: + """Quantizes `w` with scale factors `s` via Round-to-Nearest. + + Ties are broken by rounding to the nearest even number. 
+ """ + num_bits = get_num_bits(precision_info, name) + # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node, + # set block_size to the size of the quantize_axis dimension to do per-channel quantization + if block_size == -1 or num_bits == 8: + block_size = w.shape[quantize_axis] + w_padded = _pad(w, block_size, quantize_axis) + num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis] + if zp is None: + maxq = 2 ** (num_bits - 1) - 1 + minq = -(2 ** (num_bits - 1)) + w_padded = ( + np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis)) + .clip(minq, maxq) + .astype(np.int8) + ) + else: + maxq = (2**num_bits) - 1 + minq = 0 + w_padded = ( + ( + np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis)) + + zp.repeat(num_blocks, axis=quantize_axis) + ) + .clip(minq, maxq) + .astype(np.int8) + ) + return _depad(w_padded, w.shape, quantize_axis) + + +def dq_tensor( + w: np.ndarray, + s: np.ndarray, + block_size: int, + quantize_axis: int = 0, + zp: np.ndarray = None, + precision_info: dict[str, int] | None = None, + name: str | None = None, +) -> np.ndarray: + """Dequantizes `w` with scale factors `s`.""" + num_bits = get_num_bits(precision_info, name) + # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node, + # set block_size to the size of the quantize_axis dimension to do per-channel quantization + if block_size == -1 or num_bits == 8: + block_size = w.shape[quantize_axis] + w_padded = _pad(w, block_size, quantize_axis) + num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis] + if zp is None: + w_padded = w_padded * s.repeat(num_blocks, axis=quantize_axis) + else: + w_padded = (w_padded - zp.repeat(num_blocks, axis=quantize_axis)) * s.repeat( + num_blocks, axis=quantize_axis + ) + return _depad(w_padded, w.shape, quantize_axis) + + +def quant_tensor( + w: np.ndarray, + block_size: int, + quantize_axis: int = 0, + alpha: float = 1.0, + use_zero_point: bool = False, + precision_info: dict[str, int] | None = None, + name: str | None = None, +): + """Quantize a tensor using alpha etc. and return the quantized tensor.""" + scale, zp = find_scales( + w, block_size, quantize_axis, alpha, use_zero_point, precision_info, name + ) + wq = rtn(w, scale, block_size, quantize_axis, zp, precision_info, name) + return wq, scale, zp + + def quantize( input: np.ndarray, block_size: int, From b6a39be25c45b1fa415005ab2f1494e446c08e8f Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 11 Sep 2025 23:42:03 +0530 Subject: [PATCH 2/2] [5506930]Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models, refactored changes and handle comments Signed-off-by: unknown --- examples/windows/onnx_ptq/genai_llm/README.md | 1 + .../windows/onnx_ptq/genai_llm/quantize.py | 8 +- modelopt/onnx/quantization/graph_utils.py | 218 ++++++++++++ modelopt/onnx/quantization/int4.py | 309 ++++-------------- modelopt/onnx/quantization/qdq_utils.py | 79 ++--- modelopt/onnx/quantization/quant_utils.py | 89 +++-- 6 files changed, 390 insertions(+), 314 deletions(-) diff --git a/examples/windows/onnx_ptq/genai_llm/README.md b/examples/windows/onnx_ptq/genai_llm/README.md index 5cfbedda..dea6c954 100644 --- a/examples/windows/onnx_ptq/genai_llm/README.md +++ b/examples/windows/onnx_ptq/genai_llm/README.md @@ -56,6 +56,7 @@ The table below lists key command-line arguments of the ONNX PTQ example script. 
| `--awqclip_bsz_col` | 1024 (default) | Chunk size in columns during weight clipping, user-defined | | `--calibration_eps` | dml, cuda, cpu, NvTensorRtRtx (default: [dml,cpu]) | List of execution-providers to use for session run during calibration | | `--no_position_ids` | Default: position_ids input enabled | Use this option to disable position_ids input in calibration data| +| `--enable_mixed_quant` | Default: disabled mixed quant | Use this option to enable mixed precsion quantization| Run the following command to view all available parameters in the script: diff --git a/examples/windows/onnx_ptq/genai_llm/quantize.py b/examples/windows/onnx_ptq/genai_llm/quantize.py index 9d469956..57505e73 100644 --- a/examples/windows/onnx_ptq/genai_llm/quantize.py +++ b/examples/windows/onnx_ptq/genai_llm/quantize.py @@ -365,7 +365,7 @@ def main(args): f"\n--Quantize-Script-- algo={args.algo}, dataset={args.dataset}, calib_size={args.calib_size}, " f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, " f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, " - f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} k_quant_mixed={args.k_quant_mixed}\n" + f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} enable_mixed_quant={args.enable_mixed_quant}\n" ) print( @@ -435,7 +435,7 @@ def main(args): awqclip_alpha_step=args.awqclip_alpha_step, awqclip_alpha_min=args.awqclip_alpha_min, awqclip_bsz_col=args.awqclip_bsz_col, - k_quant_mixed=args.k_quant_mixed, + enable_mixed_quant=args.enable_mixed_quant, int8_layers=args.int8_layers, ) logging.info(f"\nQuantization process took {time.time() - t} seconds") @@ -597,10 +597,10 @@ def main(args): action="store_true", ) parser.add_argument( - "--k_quant_mixed", + "--enable_mixed_quant", default=False, action="store_true", - help="True when we want to use k_quant_mixed quantization", + help="True when we want to use mixed quantization", ) parser.add_argument( "--int8_layers", diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index ef18e181..3a926b42 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -18,6 +18,7 @@ import re from collections import defaultdict from functools import reduce +from typing import Any, cast import numpy as np import onnx @@ -625,6 +626,223 @@ def _find_nodes_from_op_types_to_exclude(graph: Graph, op_types_to_exclude=None) return nodes_to_exclude +def _find_quantizable_weights( + graph: onnx.GraphProto, + nodes_to_exclude: list[str], +) -> list[tuple[onnx.ValueInfoProto, onnx.ValueInfoProto, bool, int]]: + """Finds the quantizable weights from the graph.""" + wa_pack = [] + gemm_nodes = [ + node + for node in graph.node + if node.op_type in ["Gemm", "MatMul"] and node.name not in nodes_to_exclude + ] + initializer_idxs = {initializer.name: idx for idx, initializer in enumerate(graph.initializer)} + for gemm in gemm_nodes: + if gemm.input[0] in initializer_idxs: + # Ex. two const input to MatMul_115 in fastvit0.onnx + # Note. 
RTN algorithm will quantize these weights though + continue + + if gemm.input[1] not in initializer_idxs: + continue + + weight_tensor = graph.initializer[initializer_idxs[gemm.input[1]]] + if len(weight_tensor.dims) == 1: # 1D blocked quantization not supported + continue + + gemm_io_type = cast("int", weight_tensor.data_type) + + act_tensor = onnx.helper.ValueInfoProto() + act_tensor.name = gemm.input[0] + + # TODO: support transA by transposing activation tensors in _clip_search + do_transpose = gemm.op_type == "Gemm" and any( + attr.name == "transB" and attr.i > 0 for attr in gemm.attribute + ) + + wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type)) + + return wa_pack + + +def should_quantize_to_int8(layer_name: str, int8_layers: list[str]): + """Check if layer should be quantized to INT8. + + The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'. + The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'. + + To match these, we: + - Remove the leading slash from the node name. + - Replace all '/' with '.' to match the naming convention of the initializer. + + This allows us to correctly identify which weights should be quantized to INT8. + """ + if not int8_layers: + return False + + # Normalize both to dot-delimited tokens and require exact token sequence match. + def tokens(s: str) -> list[str]: + return s.lstrip("/").replace("/", ".").split(".") + + hay = tokens(layer_name) + for pat in int8_layers: + needle = tokens(pat) + n, m = len(hay), len(needle) + for i in range(n - m + 1): + if hay[i : i + m] == needle: + return True + return False + + +def validate_int8_layers(layers_str: str) -> bool: + """Validate the format of int8_layers string.""" + if not layers_str: + return True + # Basic validation: check for valid characters and structure + import re + + pattern = r"^[a-zA-Z0-9_.,\-]$" + return bool(re.match(pattern, layers_str)) + + +def get_layer_precision_mapping( + onnx_model: onnx.ModelProto, + int8_precision_pattern: str | None = None, + nodes_to_exclude: list[str] | None = [r"/lm_head"], +): + """Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model. + + Args: + onnx_model (onnx.ModelProto): The ONNX model to analyze. + int8_precision_pattern (str, optional): Comma-separated string of layer patterns to quantize to INT8. + If None, a default set of patterns is used to select layers for INT8 quantization. + nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization. + Defaults to [r"/lm_head"]. + + Returns: + dict: A mapping from layer names to their quantization precision (e.g., {"layer_name": "int8"}). + """ + graph = onnx_model.graph + + nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude) + # Collect quantizable weight tensors + wa_pack = _find_quantizable_weights(graph, nodes_to_exclude) + + if int8_precision_pattern: + if not validate_int8_layers(int8_precision_pattern): + raise ValueError("Invalid format for --int8_layers. 
Use comma-separated layers.") + int8_layers_list = [x.strip() for x in int8_precision_pattern.split(",") if x.strip()] + + else: + matmul_nodes = [ + node + for node in onnx_model.graph.node + if node.op_type == "MatMul" and "lm_head" not in node.name + ] + + # Only include nodes matching the specified patterns for all layers present in the model + # For example, for all i where a node exists with name: + # /model/layers.{i}/attn/qkv_proj/MatMul + # /model/layers.{i}/attn/v_proj/MatMul + # /model/layers.{i}/mlp/down_proj/MatMul + pattern_regexes = [ + re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"), + re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"), + re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"), + ] + + # Filter matmul_nodes to only those matching the patterns + filtered_matmul_nodes = [] + for node in matmul_nodes: + for pat in pattern_regexes: + if pat.match(node.name): + filtered_matmul_nodes.append(node) + break + + # Build a mapping from group key to list of node names (ordered by layer index if possible) + def extract_group_key(node_name): + # Extract the two components before 'MatMul' in the name, e.g. ...foo.bar.MatMul + parts = node_name.split("/") + if len(parts) >= 3: + return ".".join(parts[-3:-1]) + return node_name + + group_to_nodes = {} + for node in filtered_matmul_nodes: + group_key = extract_group_key(node.name) + group_to_nodes.setdefault(group_key, []).append(node.name) + + int8_layers_set = set() + for names in group_to_nodes.values(): + n = len(names) + if n == 0: + continue + + # Try to sort by layer index if present + def layer_idx(name): + m = re.search(r"layers\.(\d+)\.", name) + return int(m.group(1)) if m else 0 + + names_sorted = sorted(names, key=layer_idx) + first_eighth = int(n // 8) + last_eighth = int(n // 8) + # First 1/8 + int8_layers_set.update(names_sorted[:first_eighth]) + # Last 1/8 + if last_eighth > 0: + int8_layers_set.update(names_sorted[-last_eighth:]) + # Every third in the rest (excluding first and last eighth) + rest_start = first_eighth + rest_end = n - last_eighth + for i in range(rest_start, rest_end): + if (i - rest_start) % 3 == 0: + int8_layers_set.add(names_sorted[i]) + int8_layers_list = list(int8_layers_set) + + # NEW: Create precision info mapping + precision_info = {} + for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack): + weight_name = weight_tensor.name + if should_quantize_to_int8(weight_name, int8_layers_list): + precision_info[weight_name] = 8 + else: + precision_info[weight_name] = 4 + return precision_info + + +def get_precision_info( + onnx_model: onnx.ModelProto, + nodes_to_exclude: list[str] | None = [r"/lm_head"], + **kwargs: Any, +): + """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits). + + This function determines the quantization precision for each weight tensor in the ONNX model, + based on the provided configuration. If mixed quantization is enabled, it uses the layer + precision mapping; otherwise, it returns None. + + Args: + onnx_model (onnx.ModelProto): The ONNX model to analyze. + nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization. + **kwargs: Additional keyword arguments, such as: + - enable_mixed_quant (bool): Whether to enable mixed quantization. + - int8_layers (str): Comma-separated list of layer patterns to quantize to INT8. 
+ + Returns: + dict[str, int] | None: A mapping from weight tensor names to their quantization precision, + or None if mixed quantization is not enabled. + """ + precision_info = None + enable_mixed_quant = kwargs.get("enable_mixed_quant", False) + int8_layers = kwargs.get("int8_layers") + if enable_mixed_quant: + precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude) + else: + precision_info = None + return precision_info + + def expand_node_names_from_patterns( graph: onnx.GraphProto | Graph, name_patterns: list[str] | None = None ) -> list[str]: diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py index 16737120..acd7bdb1 100644 --- a/modelopt/onnx/quantization/int4.py +++ b/modelopt/onnx/quantization/int4.py @@ -19,7 +19,6 @@ import gc import math import os -import re import tempfile import time from collections.abc import Sequence @@ -37,13 +36,24 @@ from modelopt.onnx.op_types import is_fusible_scaling_op from modelopt.onnx.quantization.calib_utils import RandomDataProvider from modelopt.onnx.quantization.graph_utils import ( + _find_quantizable_weights, expand_node_names_from_patterns, + get_precision_info, get_tensor_consumer_nodes, get_tensor_producer_nodes, ) from modelopt.onnx.quantization.gs_patching import patch_gs_modules from modelopt.onnx.quantization.ort_utils import create_inference_session -from modelopt.onnx.quantization.quant_utils import _pad, dq_tensor, find_scales, quant_tensor, rtn +from modelopt.onnx.quantization.quant_utils import ( + _pad, + dq_tensor, + find_scales, + get_num_bits, + quant_tensor, + rtn, + update_block_size, + update_scale_map_for_per_channel_nodes, +) from modelopt.onnx.utils import save_onnx __all__ = ["quantize"] @@ -94,6 +104,7 @@ def _quantize_gather_nodes( block_size: int, use_zero_point: bool, dq_only: bool, + precision_info: dict[str, int] | None, ): """Return scale, zero-point, and weights for quantizable gather nodes using INT4 RTN.""" t = time.time() @@ -110,11 +121,16 @@ def _quantize_gather_nodes( continue name = in_tensor.name w = in_tensor.values + num_bits = get_num_bits(precision_info, name) + block_size_updated = update_block_size( + num_bits, block_size, w=w, quantize_axis=gather_quantize_axis + ) s, zp = find_scales( np.asarray(w), - block_size, + block_size_updated, quantize_axis=gather_quantize_axis, use_zero_point=use_zero_point, + num_bits=num_bits, ) s = s.astype(w.dtype) scales_map[name] = s @@ -130,9 +146,10 @@ def _quantize_gather_nodes( qw = rtn( np.asarray(w), s, - block_size, + block_size_updated, quantize_axis=gather_quantize_axis, zp=zp if zp is None else zp.astype(w.dtype), + num_bits=num_bits, ) weights_map[name] = qw.astype(weight_dtype) else: @@ -153,6 +170,7 @@ def _quantize_gather_nodes( ) else: logger.info("Found 0 Gather nodes to quantize") + scales_map = update_scale_map_for_per_channel_nodes(scales_map, block_size, precision_info) return weights_map, scales_map, zero_point_map @@ -161,7 +179,6 @@ def quantize_rtn( block_size: int, dq_only: bool = False, nodes_to_exclude: list[str] = [], - precision_info: dict[str, int] | None = None, **kwargs: Any, ) -> onnx.ModelProto: """Quantizes `onnx_model` using the RTN (Round-to-Nearest) algorithm. 
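With the refactor, quantize_rtn and the AWQ paths derive the precision map internally via get_precision_info. A minimal sketch of that call (the model path is hypothetical; with int8_layers left unset, the default layer-selection heuristic from graph_utils is used):

    import onnx
    from modelopt.onnx.quantization.graph_utils import get_precision_info

    onnx_model = onnx.load("model.onnx")  # hypothetical path
    precision_info = get_precision_info(
        onnx_model,
        nodes_to_exclude=[r"/lm_head"],
        enable_mixed_quant=True,  # without this flag the function returns None (pure INT4)
    )
    # precision_info maps initializer names to bit widths, e.g.
    # {"model.layers.0.attn.qkv_proj.MatMul.weight": 8,
    #  "model.layers.0.mlp.gate_up_proj.MatMul.weight": 4, ...}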
@@ -201,9 +218,13 @@ def quantize_rtn( logger.info("Computing scales for gemm weights") scales = {} gemm_io_type = {} + precision_info = get_precision_info(onnx_model, nodes_to_exclude, **kwargs) for name, w in gemm_weights.items(): logger.debug(f"Computing scales for weight {name} of shape {w.shape}") - s, zp = find_scales(np.asarray(w), block_size, precision_info=precision_info, name=name) + num_bits = get_num_bits(precision_info, name) + block_size_updated = update_block_size(num_bits, block_size, w=w) + s, zp = find_scales(np.asarray(w), block_size_updated, num_bits=num_bits) + assert zp is None, "zero-point is not enabled but zp is found non-None" scales[name] = s gemm_io_type[name] = onnx.helper.np_dtype_to_tensor_dtype(cast("int", w.dtype)) @@ -228,29 +249,28 @@ def quantize_rtn( gather_block_size, use_zero_point=False, dq_only=dq_only, + precision_info=precision_info, ) - is_per_channel = block_size == -1 if dq_only: # Calculate actual quantized weights. logger.info("Computing quantized weights for DQ-only mode") gemm_weights_quantized = {} for name, w in gemm_weights.items(): logger.debug(f"Quantizing weight {name}") - qw = rtn( - np.asarray(w), scales[name], block_size, precision_info=precision_info, name=name - ) + num_bits = get_num_bits(precision_info, name) + block_size_updated = update_block_size(num_bits, block_size, w=w) + qw = rtn(np.asarray(w), scales[name], block_size_updated, num_bits=num_bits) if has_cupy: qw = np.asnumpy(qw) scales[name] = np.asnumpy(scales[name]) gemm_weights_quantized[name] = numpy.asarray(qw) - + scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info) qdq.insert_dq_nodes( graph, scales, quantized_weights=gemm_weights_quantized, precision_info=precision_info, - is_per_channel=is_per_channel, ) if gather_w_map is not None: @@ -260,12 +280,12 @@ def quantize_rtn( gather_s_map, quantized_weights=gather_w_map, precision_info=precision_info, - is_per_channel=is_per_channel, ) else: if has_cupy: for name in scales: scales[name] = np.asnumpy(scales[name]) + scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info) qdq.insert_qdq_nodes(graph, scales, weight_map=gemm_tensors, precision_info=precision_info) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" @@ -317,8 +337,7 @@ def _clip_search( w: np.ndarray, awq_clip: AWQClipHelper, max_tokens: int = 64, - precision_info: dict[str, int] | None = None, - name: str | None = None, + num_bits: int = 4, **kwargs, ): """Apply AWQ algorithm on a weight and return optimum alpha. 
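To illustrate the per-weight dispatch used above: get_num_bits falls back to 4 when a weight is not listed in precision_info, and update_block_size then widens the block for INT8 weights. The body of update_block_size is not shown in this hunk, so the widened value below is an assumption inferred from how the per-channel fallback is described in the surrounding comments (one block spanning the whole quantize axis):

    import numpy as np
    from modelopt.onnx.quantization.quant_utils import get_num_bits, update_block_size

    precision_info = {"model.layers.0.attn.v_proj.MatMul.weight": 8}  # hypothetical name
    w = np.zeros((4096, 4096), dtype=np.float32)

    get_num_bits(precision_info, "model.layers.0.attn.v_proj.MatMul.weight")   # -> 8
    get_num_bits(precision_info, "model.layers.0.mlp.down_proj.MatMul.weight")  # -> 4 (default)

    # Assumed behaviour: INT8 (or block_size == -1) falls back to per-channel,
    # i.e. the block covers the full quantize axis; INT4 keeps its block size.
    update_block_size(8, 128, w=w)   # -> 4096 (per-channel)
    update_block_size(4, 128, w=w)   # -> 128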
@@ -359,11 +378,8 @@ def _clip_search( # Compute loss for each alpha value for alpha in awq_clip.loss: # Perform QDQ on the whole original weight tensor - qw, scales, _ = quant_tensor( - w_copy, block_size, alpha=alpha, precision_info=precision_info, name=name - ) - cur_w = dq_tensor(qw, scales, block_size, precision_info=precision_info, name=name) - + qw, scales, _ = quant_tensor(w_copy, block_size, alpha=alpha, num_bits=num_bits) + cur_w = dq_tensor(qw, scales, block_size) # Reshape before getting the batch of size co_bsz to multiply with input cur_w = cur_w.T # ci, co -> co, ci cur_w = cur_w.reshape(co, 1, -1, block_size) # co, 1, n_block, block_size @@ -383,46 +399,6 @@ def _clip_search( np.get_default_memory_pool().free_all_blocks() -def _find_quantizable_weights( - graph: onnx.GraphProto, - nodes_to_exclude: list[str], -) -> list[tuple[onnx.ValueInfoProto, onnx.ValueInfoProto, bool, int]]: - """Finds the quantizable weights from the graph.""" - wa_pack = [] - gemm_nodes = [ - node - for node in graph.node - if node.op_type in ["Gemm", "MatMul"] and node.name not in nodes_to_exclude - ] - initializer_idxs = {initializer.name: idx for idx, initializer in enumerate(graph.initializer)} - for gemm in gemm_nodes: - if gemm.input[0] in initializer_idxs: - # Ex. two const input to MatMul_115 in fastvit0.onnx - # Note. RTN algorithm will quantize these weights though - continue - - if gemm.input[1] not in initializer_idxs: - continue - - weight_tensor = graph.initializer[initializer_idxs[gemm.input[1]]] - if len(weight_tensor.dims) == 1: # 1D blocked quantization not supported - continue - - gemm_io_type = cast("int", weight_tensor.data_type) - - act_tensor = onnx.helper.ValueInfoProto() - act_tensor.name = gemm.input[0] - - # TODO: support transA by transposing activation tensors in _clip_search - do_transpose = gemm.op_type == "Gemm" and any( - attr.name == "transB" and attr.i > 0 for attr in gemm.attribute - ) - - wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type)) - - return wa_pack - - def _augment_graph( graph: onnx.GraphProto, wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int]], @@ -463,7 +439,6 @@ def _quantize_awq_clip( force_fp16: bool = False, nodes_to_exclude: list[str] = [], input_shapes_profile: Sequence[dict[str, str]] | None = None, - precision_info: dict[str, int] | None = None, **kwargs: Any, ) -> onnx.ModelProto: """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" @@ -498,7 +473,7 @@ def _quantize_awq_clip( for inp_d in data_reader: inputs.append(inp_d) assert isinstance(inp_d, dict) - + precision_info = get_precision_info(onnx_model, nodes_to_exclude, **kwargs) # Apply AWQ clip on selected weights t = time.time() alphas = {} @@ -521,11 +496,10 @@ def _quantize_awq_clip( if do_transpose: w = w.T w = np.asarray(w) - - awq_clip = AWQClipHelper(w, block_size, **kwargs) - _clip_search( - x, w, awq_clip, precision_info=precision_info, name=weight_tensor.name, **kwargs - ) + num_bits = get_num_bits(precision_info, weight_tensor.name) + block_size_updated = update_block_size(num_bits, block_size, w=w) + awq_clip = AWQClipHelper(w, block_size_updated, **kwargs) + _clip_search(x, w, awq_clip, num_bits=num_bits, **kwargs) alphas[weight_tensor.name] = awq_clip.best_alpha logger.info(f"Clip search for all weights took {time.time() - t} seconds") @@ -549,9 +523,8 @@ def _quantize_awq_clip( w = np.asarray(w) alpha = alphas.get(weight_tensor.name, 1) - qw, scale, _ = quant_tensor( - w, block_size, alpha=alpha, 
precision_info=precision_info, name=weight_tensor.name - ) + num_bits = get_num_bits(precision_info, weight_tensor.name) + qw, scale, _ = quant_tensor(w, block_size, alpha=alpha, num_bits=num_bits) if has_cupy: qw = np.asnumpy(qw) scale = np.asnumpy(scale) @@ -577,35 +550,34 @@ def _quantize_awq_clip( gather_s_map = None if gather_quantize_axis is not None: gather_w_map, gather_s_map, _ = _quantize_gather_nodes( - graph, + graph_gs, nodes_to_exclude, gather_quantize_axis, gather_block_size, use_zero_point=False, dq_only=True, + precision_info=precision_info, ) t = time.time() dq_node_attributes = {"axis": 0, "block_size": block_size} - is_per_channel = block_size == -1 + scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info) qdq.insert_dq_nodes( graph_gs, scales, quantized_weights=gemm_weights_quantized, attributes=dq_node_attributes, precision_info=precision_info, - is_per_channel=is_per_channel, ) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" gather_dq_node_attributes = {"axis": gather_quantize_axis, "block_size": gather_block_size} qdq.insert_dq_nodes( graph_gs, - scales, + gather_s_map, quantized_weights=gather_w_map, attributes=gather_dq_node_attributes, precision_info=precision_info, - is_per_channel=is_per_channel, ) logger.info(f"Inserting DQ nodes took {time.time() - t} seconds") @@ -707,7 +679,6 @@ def run_awq_scale_search_per_node( ): """Method that iterates over each quantizable node for scale search.""" assert len(awq_lite) == len(wa_pack) - for i in tqdm( range(len(wa_pack)), desc="Running AWQ scale search per node" + tqdm_msg_append_str, @@ -745,8 +716,9 @@ def run_awq_scale_search_per_node( x = np.concatenate(output_dicts[act_tensor.name], axis=0).reshape( (-1, w.shape[0]) ) # n_token, ci - - awq_lite[i] = AWQLiteHelper(x, w, block_size, **kwargs) + num_bits = get_num_bits(precision_info, weight_tensor.name) + block_size_updated = update_block_size(num_bits, block_size, w=w) + awq_lite[i] = AWQLiteHelper(x, w, block_size_updated, **kwargs) out_actual = x.__matmul__(w) @@ -758,17 +730,13 @@ def run_awq_scale_search_per_node( ) x_scaled = x * 1.0 / awq_scale w_scaled = w * awq_scale[:, np.newaxis] - qw, scale, zp = quant_tensor( w_scaled, - block_size, + block_size_updated, use_zero_point=use_zero_point, - precision_info=precision_info, - name=weight_tensor.name, - ) - dqw = dq_tensor( - qw, scale, block_size, zp=zp, precision_info=precision_info, name=weight_tensor.name + num_bits=num_bits, ) + dqw = dq_tensor(qw, scale, block_size_updated, zp=zp) out_curr = x_scaled.__matmul__(dqw) loss = np.mean(np.power((out_actual - out_curr), 2)) del out_curr @@ -777,10 +745,8 @@ def run_awq_scale_search_per_node( awq_lite[i].update_best_params() if enable_weight_clipping: w = w * (awq_lite[i].best_scale[:, np.newaxis]) - awq_clip = AWQClipHelper(w, block_size, **kwargs) - _clip_search( - x, w, awq_clip, precision_info=precision_info, name=weight_tensor.name, **kwargs - ) + awq_clip = AWQClipHelper(w, block_size_updated, **kwargs) + _clip_search(x, w, awq_clip, num_bits=num_bits, **kwargs) clip_alphas[weight_tensor.name] = awq_clip.best_alpha del x, w, out_actual, output_dicts if has_cupy: @@ -847,7 +813,6 @@ def get_x_w_mean_for_subgraph( np.get_default_memory_pool().free_all_blocks() assert w_concatenated is not None - org_shape = w_concatenated.shape w_concatenated = w_concatenated.reshape(block_size, -1) div_by = np.amax(np.abs(w_concatenated), axis=0) @@ -879,7 +844,6 @@ def 
run_awq_scale_search_per_subgraph( awq_lite, inputs, tqdm_msg_append_str, - precision_info: dict[str, int] | None = None, **kwargs: Any, ): """Method that iterates over each quantizable subgraph/siblings for scale search.""" @@ -935,21 +899,8 @@ def run_awq_scale_search_per_subgraph( assert out_act is not None x_scaled = x * 1.0 / awq_scale w_scaled = w * awq_scale[:, np.newaxis] - qw, scale, zp = quant_tensor( - w_scaled, - block_size, - use_zero_point=use_zero_point, - precision_info=precision_info, - name=weight_tensor.name, - ) - dqw = dq_tensor( - qw, - scale, - block_size, - zp=zp, - precision_info=precision_info, - name=weight_tensor.name, - ) + qw, scale, zp = quant_tensor(w_scaled, block_size, use_zero_point=use_zero_point) + dqw = dq_tensor(qw, scale, block_size, zp=zp) out_curr = x_scaled.__matmul__(dqw) loss += np.mean(np.power((out_act - out_curr), 2)) del out_curr, out_act @@ -961,7 +912,6 @@ def run_awq_scale_search_per_subgraph( best_error = loss best_alpha = alpha best_scale = awq_scale - for wa_pack_idx in wa_pack_idx_list: assert np.isnan(best_scale).sum() == 0, best_scale assert awq_lite[wa_pack_idx] is None @@ -1009,13 +959,12 @@ def _quantize_awq_lite( use_zero_point: bool = False, nodes_to_exclude: list[str] = [], input_shapes_profile: Sequence[dict[str, str]] | None = None, - precision_info: dict[str, int] | None = None, **kwargs: Any, ) -> onnx.ModelProto: """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" logger.info("Quantizing model using AWQ lite algorithm") t = time.time() - + precision_info = get_precision_info(onnx_model, nodes_to_exclude, **kwargs) run_per_subgraph = kwargs.get("awqlite_run_per_subgraph", False) fuse_nodes = kwargs.get("awqlite_fuse_nodes", True) @@ -1025,6 +974,9 @@ def _quantize_awq_lite( # TODO - evaluate/add sysram based fast-path support in per-subgraph implementation assert not run_per_subgraph or not enable_fast_path_using_high_sysram + # TODO - add support for handling awq_lite mixed precision for per-subgraph implementation + assert not run_per_subgraph or precision_info is None + augmented_model = copy.deepcopy(onnx_model) graph = augmented_model.graph @@ -1096,8 +1048,8 @@ def _quantize_awq_lite( act_to_wa_pack_map, act_to_quant_nodes_weight_shape_map = ( get_act_to_weight_map_and_act_to_wa_pack_map(wa_pack) ) - if run_per_subgraph: + # TODO - add support for handling awq_lite mixed precision for per-subgraph implementation awq_lite = run_awq_scale_search_per_subgraph( wa_pack, act_to_wa_pack_map, @@ -1109,7 +1061,6 @@ def _quantize_awq_lite( awq_lite, inputs, msg, - precision_info, **kwargs, ) else: @@ -1178,13 +1129,14 @@ def _quantize_awq_lite( assert enable_weight_clipping or (alpha == 1), ( "clip range enabled without enabling weight-clipping param" ) + num_bits = get_num_bits(precision_info, weight_tensor.name) + block_size_updated = update_block_size(num_bits, block_size, w=w_scaled) qw, scale, zp = quant_tensor( w_scaled, - block_size, + block_size_updated, alpha=alpha, use_zero_point=use_zero_point, - precision_info=precision_info, - name=weight_tensor.name, + num_bits=num_bits, ) assert use_zero_point is True or zp is None, "zp is not according to use-zero-point setting" @@ -1305,11 +1257,12 @@ def _quantize_awq_lite( gather_block_size, use_zero_point=use_zero_point, dq_only=True, + precision_info=precision_info, ) t = time.time() dq_node_attributes = {"axis": 0, "block_size": block_size} - + scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info) 
qdq.insert_dq_nodes( graph_gs, scales, @@ -1354,119 +1307,6 @@ def _quantize_awq_lite( return model -def should_quantize_to_int8(layer_name: str, int8_layers: list[str]): - """Check if layer should be quantized to INT8. - - The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'. - The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'. - - To match these, we: - - Remove the leading slash from the node name. - - Replace all '/' with '.' to match the naming convention of the initializer. - - This allows us to correctly identify which weights should be quantized to INT8. - """ - if not int8_layers: - return False - normalized_patterns = [] - for pattern in int8_layers: - p = pattern.lstrip("/") - p = p.replace("/", ".") - normalized_patterns.append(p) - return any(norm_pattern in layer_name for norm_pattern in normalized_patterns) - - -def get_layer_precision_mapping( - onnx_model: onnx.ModelProto, - int8_precision_pattern: str | None = None, - nodes_to_exclude: list[str] | None = [r"/lm_head"], -): - graph = onnx_model.graph - - nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude) - # Collect quantizable weight tensors - wa_pack = _find_quantizable_weights(graph, nodes_to_exclude) - - if int8_precision_pattern: - int8_layers_list = [x.strip() for x in int8_precision_pattern.split(",") if x.strip()] - - else: - matmul_nodes = [ - node - for node in onnx_model.graph.node - if node.op_type == "MatMul" and "lm_head" not in node.name - ] - - # Only include nodes matching the specified patterns for all layers present in the model - # For example, for all i where a node exists with name: - # /model/layers.{i}/attn/qkv_proj/MatMul - # /model/layers.{i}/attn/v_proj/MatMul - # /model/layers.{i}/mlp/down_proj/MatMul - pattern_regexes = [ - re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"), - re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"), - re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"), - ] - - # Filter matmul_nodes to only those matching the patterns - filtered_matmul_nodes = [] - for node in matmul_nodes: - for pat in pattern_regexes: - if pat.match(node.name): - filtered_matmul_nodes.append(node) - break - - # Build a mapping from group key to list of node names (ordered by layer index if possible) - def extract_group_key(node_name): - # Extract the two components before 'MatMul' in the name, e.g. 
...foo.bar.MatMul - parts = node_name.split("/") - if len(parts) >= 3: - return ".".join(parts[-3:-1]) - return node_name - - group_to_nodes = {} - for node in filtered_matmul_nodes: - group_key = extract_group_key(node.name) - group_to_nodes.setdefault(group_key, []).append(node.name) - - int8_layers_set = set() - for names in group_to_nodes.values(): - n = len(names) - if n == 0: - continue - - # Try to sort by layer index if present - def layer_idx(name): - m = re.search(r"layers\.(\d+)\.", name) - return int(m.group(1)) if m else 0 - - names_sorted = sorted(names, key=layer_idx) - first_eighth = int(n // 8) - last_eighth = int(n // 8) - # First 1/8 - int8_layers_set.update(names_sorted[:first_eighth]) - # Last 1/8 - if last_eighth > 0: - int8_layers_set.update(names_sorted[-last_eighth:]) - # Every third in the rest (excluding first and last eighth) - rest_start = first_eighth - rest_end = n - last_eighth - for i in range(rest_start, rest_end): - if (i - rest_start) % 3 == 0: - int8_layers_set.add(names_sorted[i]) - int8_layers_list = list(int8_layers_set) - - # NEW: Create precision info mapping - precision_info = {} - for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack): - weight_name = weight_tensor.name - if should_quantize_to_int8(weight_name, int8_layers_list): - precision_info[weight_name] = 8 - else: - precision_info[weight_name] = 4 - return precision_info - - def quantize( onnx_path: str | onnx.ModelProto, calibration_method: str = "awq_lite", @@ -1478,8 +1318,6 @@ def quantize( nodes_to_exclude: list[str] | None = [r"/lm_head"], log_level: str = "INFO", input_shapes_profile: Sequence[dict[str, str]] | None = None, - k_quant_mixed: bool = False, - int8_layers: str | None = None, **kwargs: Any, ) -> onnx.ModelProto: """Applies INT4 Weight-Only-Quantization (WoQ) to an ONNX model. @@ -1531,6 +1369,10 @@ def quantize( Default: None (Gather nodes not quantized). - **gather_block_size** (int): Block-size for Gather nodes quantization. Default: 32. + - **enable_mixed_quant** (bool): If True, enable mixed quantization. + Default: False. + - **int8_layers** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4. + Default: []. **Returns**: A quantized ONNX model in ONNX ModelProto format. 
""" configure_logging(level=log_level.upper()) @@ -1561,10 +1403,6 @@ def quantize( else: onnx_model = onnx_path - precision_info = None - if k_quant_mixed: - precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude) - # Initialize calibration_data_reader if not provided if calibration_data_reader is None: calibration_data_reader = RandomDataProvider(onnx_model) @@ -1578,7 +1416,6 @@ def quantize( block_size, dq_only="dq" in calibration_method, nodes_to_exclude=nodes_to_exclude, - precision_info=precision_info, **kwargs, ) elif calibration_method in ["awq_lite", "awq_full"]: @@ -1598,7 +1435,6 @@ def quantize( use_zero_point=use_zero_point, enable_weight_clipping=do_weight_clipping, input_shapes_profile=input_shapes_profile, - precision_info=precision_info, **kwargs, ) elif calibration_method in ["awq_clip", "awq_clip_trt"]: @@ -1610,7 +1446,6 @@ def quantize( block_size, nodes_to_exclude=nodes_to_exclude, input_shapes_profile=input_shapes_profile, - precision_info=precision_info, **kwargs, ) else: diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index 0a3c54e2..284d7497 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -35,6 +35,7 @@ from modelopt.onnx.quantization.quant_utils import ( compute_e8m0, get_amax, + get_num_bits, get_weights_scaling_factor, get_weights_scaling_factor_2, pack_weights_to_int4, @@ -307,18 +308,20 @@ def _insert_helper( graph.toposort() -def update_attributes(attrib: dict[str, Any] | None = None, is_per_channel: bool = False): - """Update attribute dictionary for quantization nodes. +def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int: + """Get the appropriate tensor dtype based on precision info and zero point presence. - If per-channel quantization is enabled, sets the 'axis' attribute to 1 and removes - the 'block_size' attribute if present. + Args: + num_bits: Number of bits for quantization + has_zero_point: Whether the tensor has a zero point + Returns: + ONNX tensor data type constant """ - if is_per_channel: - if attrib is not None: - attrib["axis"] = 1 - if "block_size" in attrib: - attrib.pop("block_size") - return attrib + if has_zero_point: + dtype_str = onnx_bit_dtype_unsigned_map[num_bits] + else: + dtype_str = onnx_bit_dtype_signed_map[num_bits] + return onnx_dtype_map[dtype_str] def insert_dq_nodes( @@ -328,7 +331,6 @@ def insert_dq_nodes( attributes: dict[str, Any] | None = None, zero_points: dict[str, np.ndarray] | None = None, precision_info: dict[str, int] | None = None, - is_per_channel: bool = False, ): """Insert new initializers and DQ nodes into graph. 
@@ -347,25 +349,9 @@ def _insert_helper( dq_nodes: dict[str, gs.Node], zp: np.ndarray, attrs: dict[str, Any] | None = None, - precision_info: dict[str, int] | None = None, - is_per_channel: bool = False, + num_bits: int = 4, ): - attrib = dict(attrs) if attrs is not None else None - if precision_info and name in precision_info: - tensor_dtype = ( - onnx_dtype_map[onnx_bit_dtype_signed_map[precision_info[name]]] - if zp is None - else onnx_dtype_map[onnx_bit_dtype_unsigned_map[precision_info[name]]] - ) - # do per-channel quantization for int8 as no support for int8 block-wise dq node - if precision_info[name] == 8: - # reshape scale to be per-channel - scale = scale.reshape(-1) - attrib = update_attributes(attrib, True) - else: - tensor_dtype = onnx.TensorProto.INT4 if zp is None else onnx.TensorProto.UINT4 - - attrib = update_attributes(attrib, is_per_channel) + tensor_dtype = get_tensor_dtype(num_bits, zp is not None) wq_tensor = make_gs_quantized_weight(name, wq, tensor_dtype) scale_tensor = make_gs_scale(name, scale) @@ -378,7 +364,7 @@ def _insert_helper( name, inputs=inputs, outputs=[dq_out], - attributes=attrib, + attributes=attrs, ) dq_nodes[name] = dq_node @@ -388,15 +374,22 @@ def _insert_helper( if zero_points is not None: zp = zero_points.get(name) assert zp is not None, "zero-point is enabled but zero-point values not found" + + num_bits = get_num_bits(precision_info, name) + attrs = attributes.copy() if attributes is not None else None + if ((attrs is not None) and (attrs.get("block_size", None) == -1)) or (num_bits == 8): + if attrs is not None: + attrs["axis"] = 1 + if "block_size" in attrs: + del attrs["block_size"] _insert_helper( name, quantized_weights[name], scale, dq_nodes, zp, - attributes, - precision_info, - is_per_channel, + attrs, + num_bits=num_bits, ) _postprocess_qdq( @@ -427,16 +420,9 @@ def _insert_helper( scale: np.ndarray, q_nodes: dict[str, gs.Node], dq_nodes: dict[str, gs.Node], - precision_info: dict[str, int] | None = None, + num_bits: int = 4, ): - if precision_info and name in precision_info: - tensor_dtype = onnx_dtype_map[onnx_bit_dtype_signed_map[precision_info[name]]] - # do per-channel quantization for int8 as no support for int8 block-wise dq node - if precision_info[name] == 8: - # reshape scale to be per-channel - scale = scale.reshape(-1) - else: - tensor_dtype = onnx.TensorProto.INT4 + tensor_dtype = get_tensor_dtype(num_bits) scale_tensor = make_gs_scale(name, scale) zp_tensor = make_gs_zp(name, scale.shape, tensor_dtype) @@ -453,7 +439,14 @@ def _insert_helper( q_nodes, dq_nodes = {}, {} for name, scale in scales.items(): - _insert_helper(name, weight_map[name], scale, q_nodes, dq_nodes, precision_info) + _insert_helper( + name, + weight_map[name], + scale, + q_nodes, + dq_nodes, + num_bits=get_num_bits(precision_info, name), + ) _postprocess_qdq( graph, diff --git a/modelopt/onnx/quantization/quant_utils.py b/modelopt/onnx/quantization/quant_utils.py index cab7310b..1ee503a0 100644 --- a/modelopt/onnx/quantization/quant_utils.py +++ b/modelopt/onnx/quantization/quant_utils.py @@ -157,8 +157,35 @@ def get_weights_scaling_factor( return q_per_block_scale.astype(np.float32) +def update_block_size( + num_bits: int, block_size: int, quantize_axis: int = 0, w: np.ndarray = None +) -> int: + """Update the block size for quantization. + + Args: + num_bits (int): Number of bits for quantization. + block_size (int): Current block size. If -1, per-channel quantization is used. + quantize_axis (int): Axis along which to quantize. 
+ w (np.ndarray): Weight tensor to be quantized. + + Returns: + int: Updated block size. + """ + if block_size is not None and (block_size == -1 or num_bits == 8): + return w.shape[quantize_axis] + return block_size + + def get_num_bits(precision_info: dict[str, int] | None = None, name: str | None = None) -> int: - """Determine the number of bits for quantization from precision_info.""" + """Determine the number of bits for quantization from precision_info. + + Args: + precision_info (dict[str, int] | None): Optional dictionary mapping tensor names to number of bits. + name (str | None): Name of the tensor. + + Returns: + int: Number of bits to use for quantization. Defaults to 4 if not specified. + """ if precision_info and name in precision_info: num_bits = precision_info[name] else: @@ -201,21 +228,26 @@ def _depad(w: np.ndarray, orig_shape: tuple, quantize_axis: int = 0) -> np.ndarr return ans +def update_scale_map_for_per_channel_nodes( + scales_map: dict[str, np.ndarray], block_size: int, precision_info: dict[str, int] | None = None +): + """Update the scale map for per-channel nodes.""" + for name in scales_map: + num_bits = get_num_bits(precision_info, name) + is_per_channel = (block_size == -1) or (num_bits == 8) + scales_map[name] = scales_map[name].reshape(-1) if is_per_channel else scales_map[name] + return scales_map + + def find_scales( w: np.ndarray, block_size: int, quantize_axis: int = 0, alpha: float = 1.0, use_zero_point: bool = False, - precision_info: dict[str, int] | None = None, - name: str | None = None, + num_bits: int = 4, ): """Find scale factors for `w` via `s = max(w.block(block_size)) / 7`.""" - num_bits = get_num_bits(precision_info, name) - # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node, - # set block_size to the size of the quantize_axis dimension to do per-channel quantization - if block_size == -1 or num_bits == 8: - block_size = w.shape[quantize_axis] w = _pad(w, block_size, quantize_axis) if quantize_axis == 0: w = w.T @@ -225,7 +257,7 @@ def find_scales( s_shape[-1] = s_last_dim z = None if not use_zero_point: - scale = 2 ** (num_bits - 1) + scale = 2 ** (num_bits - 1) - 1 w_amax = np.abs(w.reshape(-1, block_size)).max(axis=-1) s = (w_amax * alpha) / scale s = s.reshape(s_shape) @@ -241,6 +273,11 @@ def find_scales( temp = -temp temp = temp.clip(min=min_int, max=max_int) z = temp + # Validate zero-point values are within expected range + if not np.all((z >= min_int) & (z <= max_int)): + raise ValueError( + f"Zero-point values out of range [{min_int}, {max_int}]: min={np.min(z)}, max={np.max(z)}" + ) assert s.shape == z.shape, "s and z shape mismatch" s = s.reshape(s_shape) z = z.reshape(s_shape) @@ -258,18 +295,12 @@ def rtn( block_size: int, quantize_axis: int = 0, zp: np.ndarray = None, - precision_info: dict[str, int] | None = None, - name: str | None = None, + num_bits: int = 4, ) -> np.ndarray: """Quantizes `w` with scale factors `s` via Round-to-Nearest. Ties are broken by rounding to the nearest even number. 
""" - num_bits = get_num_bits(precision_info, name) - # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node, - # set block_size to the size of the quantize_axis dimension to do per-channel quantization - if block_size == -1 or num_bits == 8: - block_size = w.shape[quantize_axis] w_padded = _pad(w, block_size, quantize_axis) num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis] if zp is None: @@ -300,15 +331,8 @@ def dq_tensor( block_size: int, quantize_axis: int = 0, zp: np.ndarray = None, - precision_info: dict[str, int] | None = None, - name: str | None = None, ) -> np.ndarray: """Dequantizes `w` with scale factors `s`.""" - num_bits = get_num_bits(precision_info, name) - # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node, - # set block_size to the size of the quantize_axis dimension to do per-channel quantization - if block_size == -1 or num_bits == 8: - block_size = w.shape[quantize_axis] w_padded = _pad(w, block_size, quantize_axis) num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis] if zp is None: @@ -326,14 +350,19 @@ def quant_tensor( quantize_axis: int = 0, alpha: float = 1.0, use_zero_point: bool = False, - precision_info: dict[str, int] | None = None, - name: str | None = None, + num_bits: int = 4, ): - """Quantize a tensor using alpha etc. and return the quantized tensor.""" - scale, zp = find_scales( - w, block_size, quantize_axis, alpha, use_zero_point, precision_info, name - ) - wq = rtn(w, scale, block_size, quantize_axis, zp, precision_info, name) + """Quantize a tensor using alpha etc. and return the quantized tensor. + + Returns: + tuple: A tuple containing: + - wq: The quantized weight tensor (np.ndarray) + - scale: The scale factors used for quantization (np.ndarray) + - zp: The zero-point values (np.ndarray or None if not using zero-point) + """ + block_size_updated = update_block_size(num_bits, block_size, w=w) + scale, zp = find_scales(w, block_size_updated, quantize_axis, alpha, use_zero_point, num_bits) + wq = rtn(w, scale, block_size_updated, quantize_axis, zp, num_bits) return wq, scale, zp