From a5d2dbe1e2665c0791e32cc7b0ef36102ebb28a3 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 14 Oct 2025 21:20:27 +0530 Subject: [PATCH 1/2] cleanup mixed precision and gather node layer info mapping Signed-off-by: unknown --- modelopt/onnx/quantization/graph_utils.py | 73 +++++++++++---- modelopt/onnx/quantization/int4.py | 107 ++++++++++++---------- modelopt/onnx/quantization/qdq_utils.py | 12 ++- modelopt/onnx/quantization/quant_utils.py | 88 +++++++++++++++--- 4 files changed, 198 insertions(+), 82 deletions(-) diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 6b37e3e7e..781ca27e9 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -40,6 +40,9 @@ save_onnx, ) +DEFAULT_GATHER_BLOCK_SIZE = 32 +DEFAULT_GATHER_QUANTIZE_AXIS = None + def is_const_input(tensor: Tensor) -> bool: """Returns whether the given tensor is an initializer or produced by const-foldable nodes.""" @@ -717,6 +720,8 @@ def get_layer_precision_mapping( onnx_model: onnx.ModelProto, precision_pattern_8bit: str | None = None, nodes_to_exclude: list[str] | None = [r"/lm_head"], + block_size: int = 128, + quantize_axis: int = 0, ): """Generate a mapping of layer names to their quantization precision (4 bits or 8 bits) for an ONNX model. @@ -745,7 +750,7 @@ def get_layer_precision_mapping( matmul_nodes = [ node for node in onnx_model.graph.node - if node.op_type == "MatMul" and "lm_head" not in node.name + if node.op_type in ["Gemm", "MatMul"] and "lm_head" not in node.name ] # Only include nodes matching the specified patterns for all layers present in the model @@ -807,27 +812,38 @@ def layer_idx(name): layers_8bit_set.add(names_sorted[i]) layers_list_8bit = list(layers_8bit_set) - # NEW: Create precision info mapping - precision_info = {} + # NEW: Create layer info mapping with precision, block_size, and axis + layer_info = {} for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack): weight_name = weight_tensor.name if should_quantize_to_8bit(weight_name, layers_list_8bit): - precision_info[weight_name] = 8 + layer_info[weight_name] = { + "precision": 8, + "block_size": -1, # Per-channel for 8-bit + "axis": 0, + } else: - precision_info[weight_name] = 4 - return precision_info + layer_info[weight_name] = { + "precision": 4, + "block_size": block_size, # Default block size for 4-bit + "axis": quantize_axis, + } + + return layer_info -def get_precision_info( +def get_layer_info( onnx_model: onnx.ModelProto, nodes_to_exclude: list[str] | None = [r"/lm_head"], + block_size: int = 128, + quantize_axis: int = 0, **kwargs: Any, ): - """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits). + """Generate a mapping of weight tensor names to their quantization configuration. - This function determines the quantization precision for each weight tensor in the ONNX model, - based on the provided configuration. If mixed quantization is enabled, it uses the layer - precision mapping; otherwise, it returns None. + This function determines the quantization configuration (precision, block_size, axis) for each + weight tensor in the ONNX model, based on the provided configuration. If mixed quantization + is enabled, it uses the layer precision mapping; otherwise, it returns None. Args: onnx_model (onnx.ModelProto): The ONNX model to analyze. 
@@ -835,19 +851,42 @@ def get_precision_info( **kwargs: Additional keyword arguments, such as: - enable_mixed_quant (bool): Whether to enable mixed quantization. - layers_8bit (str): Comma-separated list of layer patterns to quantize to 8 bit. + - block_size (int): Default block size for quantization. + - quantize_axis (int): Default quantization axis. + - gather_block_size (int): Default block size for gather quantization. + - gather_quantize_axis (int): Default quantization axis for gather. Returns: - dict[str, int] | None: A mapping from weight tensor names to their quantization precision, - or None if mixed quantization is not enabled. + dict[str, dict[str, Any]] | None: A mapping from weight tensor names to their quantization + configuration (with keys: precision, block_size, axis), or None if mixed quantization is not enabled. """ - precision_info = None + layer_info = None enable_mixed_quant = kwargs.get("enable_mixed_quant", False) layers_8bit = kwargs.get("layers_8bit") + gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE) + gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS) if enable_mixed_quant: - precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude) + layer_info = get_layer_precision_mapping( + onnx_model, + layers_8bit, + nodes_to_exclude, + block_size, + quantize_axis, + ) else: - precision_info = None - return precision_info + layer_info = None + + if gather_quantize_axis is not None: + if layer_info is None: + layer_info = {} + for node in onnx_model.graph.node: + if node.op_type == "Gather": + layer_info[node.input[0]] = { + "precision": 4, + "block_size": gather_block_size, + "axis": gather_quantize_axis, + } + return layer_info def expand_node_names_from_patterns( diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py index 1086b5a4d..cfb3c1828 100644 --- a/modelopt/onnx/quantization/int4.py +++ b/modelopt/onnx/quantization/int4.py @@ -40,7 +40,7 @@ ) from modelopt.onnx.quantization.graph_utils import ( expand_node_names_from_patterns, - get_precision_info, + get_layer_info, get_tensor_consumer_nodes, get_tensor_producer_nodes, ) @@ -102,11 +102,9 @@ def _quantize_gather_nodes( graph: onnx.GraphProto, nodes_to_exclude: list[str], - gather_quantize_axis: int, - block_size: int, use_zero_point: bool, dq_only: bool, - precision_info: dict[str, int] | None, + layer_info: dict[str, dict] | None, ): """Return scale, zero-point, and weights for quantizable gather nodes using INT4 RTN.""" t = time.time() @@ -122,11 +120,18 @@ def _quantize_gather_nodes( # 1D blocked quantization not supported. continue name = in_tensor.name + # Get layer-specific settings from layer_info if available + if layer_info and name in layer_info: + gather_quantize_axis = layer_info[name]["axis"] + block_size = layer_info[name].get("block_size", DEFAULT_GATHER_BLOCK_SIZE) + else: + gather_quantize_axis = 0 + block_size = DEFAULT_GATHER_BLOCK_SIZE w = in_tensor.values # Updating the block size as for 8bit quantization, per-channel quantization is used. 
- num_bits = get_num_bits(precision_info, name) + num_bits = get_num_bits(layer_info, name) block_size_updated = update_block_size( - num_bits, block_size, w=w, quantize_axis=gather_quantize_axis + block_size, layer_info, name, w=w, quantize_axis=gather_quantize_axis ) s, zp = find_scales( np.asarray(w), @@ -173,7 +178,7 @@ def _quantize_gather_nodes( ) else: logger.info("Found 0 Gather nodes to quantize") - scales_map = reshape_scales_for_per_channel_nodes(scales_map, block_size, precision_info) + scales_map = reshape_scales_for_per_channel_nodes(scales_map, block_size, layer_info) return weights_map, scales_map, zero_point_map @@ -221,12 +226,12 @@ def quantize_rtn( logger.info("Computing scales for gemm weights") scales = {} gemm_io_type = {} - precision_info = get_precision_info(onnx_model, nodes_to_exclude, **kwargs) + layer_info = get_layer_info(onnx_model, nodes_to_exclude, block_size, **kwargs) for name, w in gemm_weights.items(): logger.debug(f"Computing scales for weight {name} of shape {w.shape}") # Updating the block size as for 8bit quantization, per-channel quantization is used. - num_bits = get_num_bits(precision_info, name) - block_size_updated = update_block_size(num_bits, block_size, w=w) + num_bits = get_num_bits(layer_info, name) + block_size_updated = update_block_size(block_size, layer_info, name, w=w) s, zp = find_scales(np.asarray(w), block_size_updated, num_bits=num_bits) assert zp is None, "zero-point is not enabled but zp is found non-None" @@ -240,7 +245,6 @@ def quantize_rtn( # Import the update graph graph = gs.import_onnx(onnx_model) - gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE) gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS) gather_w_map = None @@ -249,11 +253,9 @@ def quantize_rtn( gather_w_map, gather_s_map, _ = _quantize_gather_nodes( graph, nodes_to_exclude, - gather_quantize_axis, - gather_block_size, use_zero_point=False, dq_only=dq_only, - precision_info=precision_info, + layer_info=layer_info, ) if dq_only: @@ -263,43 +265,52 @@ def quantize_rtn( for name, w in gemm_weights.items(): logger.debug(f"Quantizing weight {name}") # Updating the block size as for 8bit quantization, per-channel quantization is used. 
- num_bits = get_num_bits(precision_info, name) - block_size_updated = update_block_size(num_bits, block_size, w=w) + num_bits = get_num_bits(layer_info, name) + block_size_updated = update_block_size(block_size, layer_info, name, w=w) qw = rtn(np.asarray(w), scales[name], block_size_updated, num_bits=num_bits) if has_cupy: qw = np.asnumpy(qw) scales[name] = np.asnumpy(scales[name]) gemm_weights_quantized[name] = numpy.asarray(qw) - scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info) + scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) + dq_node_attributes = {"axis": 0, "block_size": block_size} qdq.insert_dq_nodes( graph, scales, quantized_weights=gemm_weights_quantized, - precision_info=precision_info, + attributes=dq_node_attributes, + layer_info=layer_info, ) if gather_w_map is not None: + gather_dq_node_attributes = { + "axis": gather_quantize_axis, + "block_size": gather_block_size, + } assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" qdq.insert_dq_nodes( graph, gather_s_map, quantized_weights=gather_w_map, - precision_info=precision_info, + attributes=gather_dq_node_attributes, + layer_info=layer_info, ) else: if has_cupy: for name in scales: scales[name] = np.asnumpy(scales[name]) - scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info) - qdq.insert_qdq_nodes(graph, scales, weight_map=gemm_tensors, precision_info=precision_info) + scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) + qdq.insert_qdq_nodes(graph, scales, weight_map=gemm_tensors, layer_info=layer_info) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" qdq.insert_qdq_nodes( - graph, gather_s_map, weight_map=gather_w_map, precision_info=precision_info + graph, gather_s_map, weight_map=gather_w_map, layer_info=layer_info ) logger.info(f"RTN quantization completed in {time.time() - t_start:.2f} seconds") - return gs.export_onnx(graph) + model = gs.export_onnx(graph) + model.ir_version = 10 + return model class AWQClipHelper: @@ -478,7 +489,7 @@ def _quantize_awq_clip( for inp_d in data_reader: inputs.append(inp_d) assert isinstance(inp_d, dict) - precision_info = get_precision_info(onnx_model, nodes_to_exclude, **kwargs) + layer_info = get_layer_info(onnx_model, nodes_to_exclude, block_size, **kwargs) # Apply AWQ clip on selected weights t = time.time() alphas = {} @@ -501,9 +512,9 @@ def _quantize_awq_clip( if do_transpose: w = w.T w = np.asarray(w) - num_bits = get_num_bits(precision_info, weight_tensor.name) + num_bits = get_num_bits(layer_info, weight_tensor.name) # Updating the block size as for 8bit quantization, per-channel quantization is used. - block_size_updated = update_block_size(num_bits, block_size, w=w) + block_size_updated = update_block_size(block_size, layer_info, weight_tensor.name, w=w) awq_clip = AWQClipHelper(w, block_size_updated, **kwargs) _clip_search(x, w, awq_clip, num_bits=num_bits, **kwargs) alphas[weight_tensor.name] = awq_clip.best_alpha @@ -529,9 +540,9 @@ def _quantize_awq_clip( w = np.asarray(w) alpha = alphas.get(weight_tensor.name, 1) - num_bits = get_num_bits(precision_info, weight_tensor.name) + num_bits = get_num_bits(layer_info, weight_tensor.name) # Updating the block size as for 8bit quantization, per-channel quantization is used. 
- block_size_updated = update_block_size(num_bits, block_size, w=w) + block_size_updated = update_block_size(block_size, layer_info, weight_tensor.name, w=w) qw, scale, _ = quant_tensor(w, block_size_updated, alpha=alpha, num_bits=num_bits) if has_cupy: qw = np.asnumpy(qw) @@ -560,22 +571,20 @@ def _quantize_awq_clip( gather_w_map, gather_s_map, _ = _quantize_gather_nodes( graph_gs, nodes_to_exclude, - gather_quantize_axis, - gather_block_size, use_zero_point=False, dq_only=True, - precision_info=precision_info, + layer_info=layer_info, ) t = time.time() dq_node_attributes = {"axis": 0, "block_size": block_size} - scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info) + scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) qdq.insert_dq_nodes( graph_gs, scales, quantized_weights=gemm_weights_quantized, attributes=dq_node_attributes, - precision_info=precision_info, + layer_info=layer_info, ) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" @@ -585,7 +594,7 @@ def _quantize_awq_clip( gather_s_map, quantized_weights=gather_w_map, attributes=gather_dq_node_attributes, - precision_info=precision_info, + layer_info=layer_info, ) logger.info(f"Inserting DQ nodes took {time.time() - t} seconds") @@ -682,7 +691,7 @@ def run_awq_scale_search_per_node( enable_fast_path_using_high_sysram, output_data, clip_alphas, - precision_info: dict[str, int] | None = None, + layer_info: dict[str, dict] | None = None, **kwargs: Any, ): """Method that iterates over each quantizable node for scale search.""" @@ -725,8 +734,8 @@ def run_awq_scale_search_per_node( (-1, w.shape[0]) ) # n_token, ci # Updating the block size as for 8bit quantization, per-channel quantization is used. 
- num_bits = get_num_bits(precision_info, weight_tensor.name) - block_size_updated = update_block_size(num_bits, block_size, w=w) + num_bits = get_num_bits(layer_info, weight_tensor.name) + block_size_updated = update_block_size(block_size, layer_info, weight_tensor.name, w=w) awq_lite[i] = AWQLiteHelper(x, w, block_size_updated, **kwargs) out_actual = x.__matmul__(w) @@ -973,7 +982,7 @@ def _quantize_awq_lite( """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" logger.info("Quantizing model using AWQ lite algorithm") t = time.time() - precision_info = get_precision_info(onnx_model, nodes_to_exclude, **kwargs) + layer_info = get_layer_info(onnx_model, nodes_to_exclude, **kwargs) run_per_subgraph = kwargs.get("awqlite_run_per_subgraph", False) fuse_nodes = kwargs.get("awqlite_fuse_nodes", True) @@ -982,9 +991,9 @@ def _quantize_awq_lite( # TODO - evaluate/add sysram based fast-path support in per-subgraph implementation assert not run_per_subgraph or not enable_fast_path_using_high_sysram - + enable_mixed_quant = kwargs.get("enable_mixed_quant", False) # TODO - add support for handling awq_lite mixed precision for per-subgraph implementation - assert not run_per_subgraph or precision_info is None + assert not run_per_subgraph or not enable_mixed_quant augmented_model = copy.deepcopy(onnx_model) graph = augmented_model.graph @@ -1086,7 +1095,7 @@ def _quantize_awq_lite( enable_fast_path_using_high_sysram, output_data, clip_alphas, - precision_info, + layer_info, **kwargs, ) assert len(awq_lite) == len(wa_pack) @@ -1139,8 +1148,10 @@ def _quantize_awq_lite( "clip range enabled without enabling weight-clipping param" ) # Updating the block size as for 8bit quantization, per-channel quantization is used. - num_bits = get_num_bits(precision_info, weight_tensor.name) - block_size_updated = update_block_size(num_bits, block_size, w=w_scaled) + num_bits = get_num_bits(layer_info, weight_tensor.name) + block_size_updated = update_block_size( + block_size, layer_info, weight_tensor.name, w=w_scaled + ) qw, scale, zp = quant_tensor( w_scaled, block_size_updated, @@ -1263,23 +1274,21 @@ def _quantize_awq_lite( gather_w_map, gather_s_map, gather_zp_map = _quantize_gather_nodes( graph_gs, nodes_to_exclude, - gather_quantize_axis, - gather_block_size, use_zero_point=use_zero_point, dq_only=True, - precision_info=precision_info, + layer_info=layer_info, ) t = time.time() dq_node_attributes = {"axis": 0, "block_size": block_size} - scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info) + scales = reshape_scales_for_per_channel_nodes(scales, block_size, layer_info) qdq.insert_dq_nodes( graph_gs, scales, quantized_weights=gemm_weights_quantized, attributes=dq_node_attributes, zero_points=zero_points if use_zero_point else None, - precision_info=precision_info, + layer_info=layer_info, ) if gather_w_map is not None: assert gather_s_map is not None, "scale-map not found for quantizable gather nodes" @@ -1293,7 +1302,7 @@ def _quantize_awq_lite( quantized_weights=gather_w_map, attributes=gather_dq_node_attributes, zero_points=gather_zp_map if use_zero_point else None, - precision_info=precision_info, + layer_info=layer_info, ) if pre_quant_scale: qdq.insert_pre_quant_scale_nodes(graph_gs, input_tensors, pre_quant_scale) diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index c4dbdcc4a..c3edc967f 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -352,7 
+352,7 @@ def insert_dq_nodes( quantized_weights: dict[str, np.ndarray], attributes: dict[str, Any] | None = None, zero_points: dict[str, np.ndarray] | None = None, - precision_info: dict[str, int] | None = None, + layer_info: dict[str, dict] | None = None, ): """Insert new initializers and DQ nodes into graph. @@ -361,6 +361,8 @@ def insert_dq_nodes( weights: A map from ONNX initializer name to tensor. scales: A map from ONNX initializer name to desired scale factor for that initializer. dq_only: Whether to only insert dq nodes. + layer_info: Optional dictionary mapping tensor names to precision (old format) or + to layer configuration dict (new format with precision, block_size, axis). """ logger.debug(f"Inserting DQ nodes for {len(scales)} weights") @@ -397,7 +399,7 @@ def _insert_helper( zp = zero_points.get(name) assert zp is not None, "zero-point is enabled but zero-point values not found" - num_bits = get_num_bits(precision_info, name) + num_bits = get_num_bits(layer_info, name) # Updating the attributes for per-channel nodes. attrs = attributes.copy() if attributes is not None else None attrs = update_attributes_for_per_channel_nodes(attrs, num_bits) @@ -423,7 +425,7 @@ def insert_qdq_nodes( graph: gs.Graph, scales: dict[str, np.ndarray], weight_map: dict[str, gs.Tensor], - precision_info: dict[str, int] | None = None, + layer_info: dict[str, dict] | None = None, ): """Insert scales and QDQ nodes into graph. @@ -431,6 +433,8 @@ def insert_qdq_nodes( graph: The graph to modify. scales: A map from ONNX initializer name to desired scale factor for that initializer. weight_map: A map from ONNX initializer name to graphsurgeon tensor. + layer_info: Optional dictionary mapping tensor names to precision (old format) or + to layer configuration dict (new format with precision, block_size, axis). """ logger.debug(f"Inserting QDQ nodes for {len(scales)} weights") @@ -465,7 +469,7 @@ def _insert_helper( scale, q_nodes, dq_nodes, - num_bits=get_num_bits(precision_info, name), + num_bits=get_num_bits(layer_info, name), ) _postprocess_qdq( diff --git a/modelopt/onnx/quantization/quant_utils.py b/modelopt/onnx/quantization/quant_utils.py index 1624c0f1d..ac13cea2e 100644 --- a/modelopt/onnx/quantization/quant_utils.py +++ b/modelopt/onnx/quantization/quant_utils.py @@ -158,41 +158,92 @@ def get_weights_scaling_factor( def update_block_size( - num_bits: int, block_size: int, quantize_axis: int = 0, w: np.ndarray = None + block_size: int, + layer_info: dict[str, dict] | None = None, + name: str | None = None, + quantize_axis: int = 0, + w: np.ndarray = None, ) -> int: """Update the block size for quantization. Args: num_bits (int): Number of bits for quantization. - block_size (int): Current block size. If -1, per-channel quantization is used. + layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names + to layer configuration dict. + name (str | None): Name of the tensor. + block_size (int): Current block size. quantize_axis (int): Axis along which to quantize. w (np.ndarray): Weight tensor to be quantized. Returns: int: Updated block size. 
""" - if block_size is not None and (block_size == -1 or num_bits == 8): + if layer_info and name in layer_info: + block_size = layer_info[name]["block_size"] + quantize_axis = layer_info[name]["axis"] + if block_size is not None and block_size == -1: return w.shape[quantize_axis] return block_size -def get_num_bits(precision_info: dict[str, int] | None = None, name: str | None = None) -> int: - """Determine the number of bits for quantization from precision_info. +def get_num_bits(layer_info: dict[str, dict] | None = None, name: str | None = None) -> int: + """Determine the layer configuration for quantization from layer_info. Args: - precision_info (dict[str, int] | None): Optional dictionary mapping tensor names to number of bits. + layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names + to layer configuration dict. name (str | None): Name of the tensor. Returns: int: Number of bits to use for quantization. Defaults to 4 if not specified. """ - if precision_info and name in precision_info: - num_bits = precision_info[name] + if layer_info and name in layer_info: + num_bits = layer_info[name]["precision"] else: num_bits = 4 return num_bits +def get_layer_block_size( + layer_info: dict[str, dict] | None = None, + name: str | None = None, + default_block_size: int | None = None, +) -> int | None: + """Get the block size for a specific layer from layer_info. + + Args: + layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names to layer configuration. + name (str | None): Name of the tensor. + default_block_size (int | None): Default block size if not specified. Defaults to None. + + Returns: + int: Block size to use for quantization. + """ + if layer_info and name in layer_info: + return layer_info[name].get("block_size", default_block_size) + return default_block_size + + +def get_layer_axis( + layer_info: dict[str, dict] | None = None, + name: str | None = None, + default_axis: int | None = None, +) -> int | None: + """Get the quantization axis for a specific layer from layer_info. + + Args: + layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names to layer configuration. + name (str | None): Name of the tensor. + default_axis (int | None): Default axis if not specified. Defaults to None. + + Returns: + int: Quantization axis to use. + """ + if layer_info and name in layer_info: + return layer_info[name].get("axis", default_axis) + return default_axis + + def _next_block_size_multiple(x: float, block_size: int) -> float: return math.ceil(x / block_size) * block_size @@ -229,12 +280,25 @@ def _depad(w: np.ndarray, orig_shape: tuple, quantize_axis: int = 0) -> np.ndarr def reshape_scales_for_per_channel_nodes( - scales_map: dict[str, np.ndarray], block_size: int, precision_info: dict[str, int] | None = None + scales_map: dict[str, np.ndarray], + block_size: int, + layer_info: dict[str, dict] | None = None, ): - """Update the scale map for per-channel nodes. For per channel quantization the scale needs to be 1D.""" + """Update the scale map for per-channel nodes. For per channel quantization the scale needs to be 1D. + + Args: + scales_map (dict[str, np.ndarray]): Dictionary mapping weight names to scale arrays. + layer_info (dict[str, dict] | None): Optional dictionary mapping tensor names + to layer configuration dict. + + Returns: + dict[str, np.ndarray]: Updated scales map. 
+ """ for name in scales_map: - num_bits = get_num_bits(precision_info, name) - is_per_channel = (block_size == -1) or (num_bits == 8) + layer_block_size = block_size + if layer_info and name in layer_info and isinstance(layer_info[name], dict): + layer_block_size = layer_info[name].get("block_size", block_size) + is_per_channel = layer_block_size == -1 scales_map[name] = scales_map[name].reshape(-1) if is_per_channel else scales_map[name] return scales_map From 1fd797364cfd2ab119180a04344d0cdb58a453d5 Mon Sep 17 00:00:00 2001 From: ynankani-nv Date: Tue, 14 Oct 2025 22:36:05 +0530 Subject: [PATCH 2/2] Update modelopt/onnx/quantization/int4.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: ynankani-nv --- modelopt/onnx/quantization/int4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py index cfb3c1828..447bf2786 100644 --- a/modelopt/onnx/quantization/int4.py +++ b/modelopt/onnx/quantization/int4.py @@ -982,7 +982,7 @@ def _quantize_awq_lite( """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" logger.info("Quantizing model using AWQ lite algorithm") t = time.time() - layer_info = get_layer_info(onnx_model, nodes_to_exclude, **kwargs) + layer_info = get_layer_info(onnx_model, nodes_to_exclude, block_size, **kwargs) run_per_subgraph = kwargs.get("awqlite_run_per_subgraph", False) fuse_nodes = kwargs.get("awqlite_fuse_nodes", True)