
Commit f5c209d

Cleanup mixed precision and gather node layer info mapping (#434)
Signed-off-by: unknown <[email protected]>
Signed-off-by: ynankani-nv <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent: 99c76ff


4 files changed: +192, -81 lines


modelopt/onnx/quantization/graph_utils.py

Lines changed: 56 additions & 17 deletions
@@ -40,6 +40,9 @@
     save_onnx,
 )
 
+DEFAULT_GATHER_BLOCK_SIZE = 32
+DEFAULT_GATHER_QUANTIZE_AXIS = None
+
 
 def is_const_input(tensor: Tensor) -> bool:
     """Returns whether the given tensor is an initializer or produced by const-foldable nodes."""
@@ -718,6 +721,8 @@ def get_layer_precision_mapping(
     onnx_model: onnx.ModelProto,
     precision_pattern_8bit: str | None = None,
     nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    block_size: int = 128,
+    quantize_axis: int = 0,
 ):
     """Generate a mapping of layer names to their quantization precision (4 bits or 8 bits) for an ONNX model.
 
@@ -746,7 +751,7 @@ def get_layer_precision_mapping(
     matmul_nodes = [
         node
         for node in onnx_model.graph.node
-        if node.op_type == "MatMul" and "lm_head" not in node.name
+        if node.op_type in ["Gemm", "MatMul"] and "lm_head" not in node.name
     ]
 
     # Only include nodes matching the specified patterns for all layers present in the model
@@ -808,47 +813,81 @@ def layer_idx(name):
            layers_8bit_set.add(names_sorted[i])
     layers_list_8bit = list(layers_8bit_set)
 
-    # NEW: Create precision info mapping
-    precision_info = {}
+    # NEW: Create layer info mapping with precision, block_size, and axis
+    layer_info = {}
     for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack):
         weight_name = weight_tensor.name
         if should_quantize_to_8bit(weight_name, layers_list_8bit):
-            precision_info[weight_name] = 8
+            layer_info[weight_name] = {
+                "precision": 8,
+                "block_size": -1,  # Per-channel for 8-bit
+                "axis": 0,
+            }
         else:
-            precision_info[weight_name] = 4
-    return precision_info
+            layer_info[weight_name] = {
+                "precision": 4,
+                "block_size": block_size,  # Default block size for 4-bit
+                "axis": quantize_axis,
+            }
+
+    return layer_info
 
 
-def get_precision_info(
+def get_layer_info(
     onnx_model: onnx.ModelProto,
     nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    block_size: int = 128,
+    quantize_axis: int = 0,
     **kwargs: Any,
 ):
-    """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).
+    """Generate a mapping of weight tensor names to their quantization configuration.
 
-    This function determines the quantization precision for each weight tensor in the ONNX model,
-    based on the provided configuration. If mixed quantization is enabled, it uses the layer
-    precision mapping; otherwise, it returns None.
+    This function determines the quantization configuration (precision, block_size, axis) for each
+    weight tensor in the ONNX model, based on the provided configuration. If mixed quantization
+    is enabled, it uses the layer precision mapping; otherwise, it returns None.
 
     Args:
         onnx_model (onnx.ModelProto): The ONNX model to analyze.
         nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
         **kwargs: Additional keyword arguments, such as:
            - enable_mixed_quant (bool): Whether to enable mixed quantization.
            - layers_8bit (str): Comma-separated list of layer patterns to quantize to 8 bit.
+           - block_size (int): Default block size for quantization.
+           - quantize_axis (int): Default quantization axis.
+           - gather_block_size (int): Default block size for gather quantization.
+           - gather_quantize_axis (int): Default quantization axis for gather.
 
     Returns:
-        dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
-        or None if mixed quantization is not enabled.
+        dict[str, dict[str, Any]] | None: A mapping from weight tensor names to their quantization
+        configuration (with keys: precision, block_size, axis), or None if mixed quantization is not enabled.
     """
-    precision_info = None
+    layer_info = None
    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
    layers_8bit = kwargs.get("layers_8bit")
+    gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE)
+    gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS)
    if enable_mixed_quant:
-        precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude)
+        layer_info = get_layer_precision_mapping(
+            onnx_model,
+            layers_8bit,
+            nodes_to_exclude,
+            block_size,
+            quantize_axis,
+        )
    else:
-        precision_info = None
-    return precision_info
+        layer_info = None
+
+    if gather_quantize_axis is not None:
+        if layer_info is None:
+            layer_info = {}
+        for node in onnx_model.graph.node:
+            if node.op_type == "Gather":
+                layer_info[node.input[0]] = {
+                    "precision": 4,
+                    "block_size": gather_block_size,
+                    "axis": gather_quantize_axis,
+                }
+    return layer_info
 
 
 def expand_node_names_from_patterns(

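Downstream, a consumer of this mapping can branch on each entry's config. The sketch below is illustrative only: the helper name and the NumPy-based blocking scheme are assumptions, not code from this commit; it just shows how precision, block_size, and axis could drive a symmetric fake-quantization of a 2-D weight.

import numpy as np


def fake_quantize(weight: np.ndarray, cfg: dict) -> np.ndarray:
    """Toy symmetric fake-quantization of a 2-D weight from one layer_info entry.

    Per-channel when cfg["block_size"] == -1, otherwise block-wise with
    blocks of cfg["block_size"] elements; cfg["axis"] picks the channel axis.
    """
    bits, axis, block = cfg["precision"], cfg["axis"], cfg["block_size"]
    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit, 127 for 8-bit
    w = weight if axis == 0 else weight.T
    rows, cols = w.shape
    if block == -1:
        block = cols  # per-channel: one scale spanning each whole row
    assert cols % block == 0, "toy sketch assumes the block size divides the axis"
    blocks = w.reshape(rows, cols // block, block)
    scales = np.abs(blocks).max(axis=-1, keepdims=True) / qmax
    scales[scales == 0] = 1.0  # guard all-zero blocks against division by zero
    q = np.clip(np.round(blocks / scales), -qmax - 1, qmax)
    dq = (q * scales).reshape(rows, cols)
    return dq if axis == 0 else dq.T


# Example: apply the 4-bit, block-128, axis-0 config from the sketch above.
w = np.random.randn(256, 512).astype(np.float32)
w_dq = fake_quantize(w, {"precision": 4, "block_size": 128, "axis": 0})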