
Commit b6a39be

[5506930] Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models; refactor changes and address review comments
Signed-off-by: unknown <[email protected]>
1 parent 024f97a commit b6a39be

File tree

6 files changed (+390, -314 lines)


examples/windows/onnx_ptq/genai_llm/README.md

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ The table below lists key command-line arguments of the ONNX PTQ example script.
 | `--awqclip_bsz_col` | 1024 (default) | Chunk size in columns during weight clipping, user-defined |
 | `--calibration_eps` | dml, cuda, cpu, NvTensorRtRtx (default: [dml,cpu]) | List of execution-providers to use for session run during calibration |
 | `--no_position_ids` | Default: position_ids input enabled | Use this option to disable position_ids input in calibration data|
+| `--enable_mixed_quant` | Default: mixed quant disabled | Use this option to enable mixed-precision quantization|

 Run the following command to view all available parameters in the script:
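Separately, for illustration, hedged example invocations of the new option (the `<model-and-output-arguments>` placeholder stands in for the script's other required arguments, which are outside this diff; the `--int8_layers` node names are illustrative):

    python quantize.py <model-and-output-arguments> --enable_mixed_quant
    python quantize.py <model-and-output-arguments> --enable_mixed_quant --int8_layers "/model/layers.13/attn/qkv_proj/MatMul,/model/layers.13/mlp/down_proj/MatMul"

The first form relies on the default INT8 layer-selection heuristic; the second pins the listed MatMul nodes to INT8 and leaves the remaining quantized weights at INT4.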

examples/windows/onnx_ptq/genai_llm/quantize.py

Lines changed: 4 additions & 4 deletions
@@ -365,7 +365,7 @@ def main(args):
         f"\n--Quantize-Script-- algo={args.algo}, dataset={args.dataset}, calib_size={args.calib_size}, "
         f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, "
         f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, "
-        f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} k_quant_mixed={args.k_quant_mixed}\n"
+        f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} enable_mixed_quant={args.enable_mixed_quant}\n"
     )

     print(
@@ -435,7 +435,7 @@ def main(args):
         awqclip_alpha_step=args.awqclip_alpha_step,
         awqclip_alpha_min=args.awqclip_alpha_min,
         awqclip_bsz_col=args.awqclip_bsz_col,
-        k_quant_mixed=args.k_quant_mixed,
+        enable_mixed_quant=args.enable_mixed_quant,
         int8_layers=args.int8_layers,
     )
     logging.info(f"\nQuantization process took {time.time() - t} seconds")
@@ -597,10 +597,10 @@ def main(args):
         action="store_true",
     )
     parser.add_argument(
-        "--k_quant_mixed",
+        "--enable_mixed_quant",
         default=False,
         action="store_true",
-        help="True when we want to use k_quant_mixed quantization",
+        help="Enable mixed-precision (INT4+INT8) quantization",
     )
     parser.add_argument(
         "--int8_layers",

modelopt/onnx/quantization/graph_utils.py

Lines changed: 218 additions & 0 deletions
@@ -18,6 +18,7 @@
 import re
 from collections import defaultdict
 from functools import reduce
+from typing import Any, cast

 import numpy as np
 import onnx
@@ -625,6 +626,223 @@ def _find_nodes_from_op_types_to_exclude(graph: Graph, op_types_to_exclude=None)
     return nodes_to_exclude


+def _find_quantizable_weights(
+    graph: onnx.GraphProto,
+    nodes_to_exclude: list[str],
+) -> list[tuple[onnx.ValueInfoProto, onnx.TensorProto, bool, int]]:
+    """Finds the quantizable weights from the graph."""
+    wa_pack = []
+    gemm_nodes = [
+        node
+        for node in graph.node
+        if node.op_type in ["Gemm", "MatMul"] and node.name not in nodes_to_exclude
+    ]
+    initializer_idxs = {initializer.name: idx for idx, initializer in enumerate(graph.initializer)}
+    for gemm in gemm_nodes:
+        if gemm.input[0] in initializer_idxs:
+            # Ex. two const inputs to MatMul_115 in fastvit0.onnx
+            # Note: the RTN algorithm will quantize these weights though
+            continue
+
+        if gemm.input[1] not in initializer_idxs:
+            continue
+
+        weight_tensor = graph.initializer[initializer_idxs[gemm.input[1]]]
+        if len(weight_tensor.dims) == 1:  # 1D blocked quantization not supported
+            continue
+
+        gemm_io_type = cast("int", weight_tensor.data_type)
+
+        act_tensor = onnx.helper.ValueInfoProto()
+        act_tensor.name = gemm.input[0]
+
+        # TODO: support transA by transposing activation tensors in _clip_search
+        do_transpose = gemm.op_type == "Gemm" and any(
+            attr.name == "transB" and attr.i > 0 for attr in gemm.attribute
+        )
+
+        wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type))
+
+    return wa_pack
+
+
+def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
+    """Check if layer should be quantized to INT8.
+
+    The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
+    The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'.
+
+    To match these, we:
+    - Remove the leading slash from the node name.
+    - Replace all '/' with '.' to match the naming convention of the initializer.
+
+    This allows us to correctly identify which weights should be quantized to INT8.
+    """
+    if not int8_layers:
+        return False
+
+    # Normalize both to dot-delimited tokens and require an exact token-sequence match.
+    def tokens(s: str) -> list[str]:
+        return s.lstrip("/").replace("/", ".").split(".")
+
+    hay = tokens(layer_name)
+    for pat in int8_layers:
+        needle = tokens(pat)
+        n, m = len(hay), len(needle)
+        for i in range(n - m + 1):
+            if hay[i : i + m] == needle:
+                return True
+    return False
+
+
+def validate_int8_layers(layers_str: str) -> bool:
+    """Validate the format of the int8_layers string."""
+    if not layers_str:
+        return True
+    # Basic validation: allow only characters expected in comma-separated ONNX node-name
+    # patterns (alphanumerics, '_', '.', '/', ',' and '-').
+    pattern = r"^[a-zA-Z0-9_./,\-]+$"
+    return bool(re.match(pattern, layers_str))
+
+
+def get_layer_precision_mapping(
+    onnx_model: onnx.ModelProto,
+    int8_precision_pattern: str | None = None,
+    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+):
+    """Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model.
+
+    Args:
+        onnx_model (onnx.ModelProto): The ONNX model to analyze.
+        int8_precision_pattern (str, optional): Comma-separated string of layer patterns to quantize to INT8.
+            If None, a default set of patterns is used to select layers for INT8 quantization.
+        nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization.
+            Defaults to [r"/lm_head"].
+
+    Returns:
+        dict: A mapping from weight tensor names to their quantization precision (e.g., {"layer_name": 8}).
+    """
+    graph = onnx_model.graph
+
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    # Collect quantizable weight tensors
+    wa_pack = _find_quantizable_weights(graph, nodes_to_exclude)
+
+    if int8_precision_pattern:
+        if not validate_int8_layers(int8_precision_pattern):
+            raise ValueError("Invalid format for --int8_layers. Use comma-separated layer names.")
+        int8_layers_list = [x.strip() for x in int8_precision_pattern.split(",") if x.strip()]
+
+    else:
+        matmul_nodes = [
+            node
+            for node in onnx_model.graph.node
+            if node.op_type == "MatMul" and "lm_head" not in node.name
+        ]
+
+        # Only include nodes matching the specified patterns for all layers present in the model.
+        # For example, for all i where a node exists with name:
+        # /model/layers.{i}/attn/qkv_proj/MatMul
+        # /model/layers.{i}/attn/v_proj/MatMul
+        # /model/layers.{i}/mlp/down_proj/MatMul
+        pattern_regexes = [
+            re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"),
+            re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"),
+            re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"),
+        ]
+
+        # Filter matmul_nodes to only those matching the patterns
+        filtered_matmul_nodes = []
+        for node in matmul_nodes:
+            for pat in pattern_regexes:
+                if pat.match(node.name):
+                    filtered_matmul_nodes.append(node)
+                    break
+
+        # Build a mapping from group key to list of node names (ordered by layer index if possible)
+        def extract_group_key(node_name):
+            # Use the two path components before 'MatMul', e.g. 'attn.qkv_proj' or 'mlp.down_proj'
+            parts = node_name.split("/")
+            if len(parts) >= 3:
+                return ".".join(parts[-3:-1])
+            return node_name
+
+        group_to_nodes = {}
+        for node in filtered_matmul_nodes:
+            group_key = extract_group_key(node.name)
+            group_to_nodes.setdefault(group_key, []).append(node.name)
+
+        int8_layers_set = set()
+        for names in group_to_nodes.values():
+            n = len(names)
+            if n == 0:
+                continue
+
+            # Try to sort by layer index if present
+            def layer_idx(name):
+                m = re.search(r"layers\.(\d+)", name)
+                return int(m.group(1)) if m else 0
+
+            names_sorted = sorted(names, key=layer_idx)
+            first_eighth = n // 8
+            last_eighth = n // 8
+            # First 1/8
+            int8_layers_set.update(names_sorted[:first_eighth])
+            # Last 1/8
+            if last_eighth > 0:
+                int8_layers_set.update(names_sorted[-last_eighth:])
+            # Every third layer in the rest (excluding the first and last eighth)
+            rest_start = first_eighth
+            rest_end = n - last_eighth
+            for i in range(rest_start, rest_end):
+                if (i - rest_start) % 3 == 0:
+                    int8_layers_set.add(names_sorted[i])
+        int8_layers_list = list(int8_layers_set)
+
+    # Create the precision info mapping: 8 bits for the selected layers, 4 bits otherwise
+    precision_info = {}
+    for _, weight_tensor, _, _ in wa_pack:
+        weight_name = weight_tensor.name
+        if should_quantize_to_int8(weight_name, int8_layers_list):
+            precision_info[weight_name] = 8
+        else:
+            precision_info[weight_name] = 4
+    return precision_info
+
+
+def get_precision_info(
+    onnx_model: onnx.ModelProto,
+    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    **kwargs: Any,
+):
+    """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).
+
+    This function determines the quantization precision for each weight tensor in the ONNX model,
+    based on the provided configuration. If mixed quantization is enabled, it uses the layer
+    precision mapping; otherwise, it returns None.
+
+    Args:
+        onnx_model (onnx.ModelProto): The ONNX model to analyze.
+        nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
+        **kwargs: Additional keyword arguments, such as:
+            - enable_mixed_quant (bool): Whether to enable mixed quantization.
+            - int8_layers (str): Comma-separated list of layer patterns to quantize to INT8.
+
+    Returns:
+        dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
+            or None if mixed quantization is not enabled.
+    """
+    precision_info = None
+    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
+    int8_layers = kwargs.get("int8_layers")
+    if enable_mixed_quant:
+        precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude)
+    return precision_info
+
+
 def expand_node_names_from_patterns(
     graph: onnx.GraphProto | Graph, name_patterns: list[str] | None = None
 ) -> list[str]:
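For orientation, the following standalone sketch mirrors the two pieces of logic added above: the token-sequence matching used by should_quantize_to_int8 and the default first-eighth / last-eighth / every-third selection heuristic from get_layer_precision_mapping. The 16-layer model and the qkv_proj node names are hypothetical inputs chosen for illustration; the snippet reimplements the logic rather than importing it from modelopt.

    import re

    def tokens(s: str) -> list[str]:
        # Same normalization as should_quantize_to_int8: strip the leading '/',
        # turn '/' into '.', and split into dot-delimited tokens.
        return s.lstrip("/").replace("/", ".").split(".")

    def matches_int8(layer_name: str, int8_layers: list[str]) -> bool:
        # True if any pattern's token sequence appears contiguously in the layer name.
        hay = tokens(layer_name)
        for pat in int8_layers:
            needle = tokens(pat)
            for i in range(len(hay) - len(needle) + 1):
                if hay[i : i + len(needle)] == needle:
                    return True
        return False

    # An initializer name matches its node name after normalization.
    print(matches_int8(
        "model.layers.13.attn.qkv_proj.MatMul.weight",
        ["/model/layers.13/attn/qkv_proj/MatMul"],
    ))  # -> True

    # Default heuristic applied to one projection group of a hypothetical 16-layer model:
    # keep the first and last eighth in INT8, plus every third layer in between.
    names = [f"/model/layers.{i}/attn/qkv_proj/MatMul" for i in range(16)]
    names.sort(key=lambda s: int(re.search(r"layers\.(\d+)", s).group(1)))
    n = len(names)
    eighth = n // 8
    selected = set(names[:eighth])
    if eighth > 0:
        selected.update(names[-eighth:])
    selected.update(names[i] for i in range(eighth, n - eighth) if (i - eighth) % 3 == 0)
    print(sorted(int(re.search(r"layers\.(\d+)", s).group(1)) for s in selected))
    # -> [0, 1, 2, 5, 8, 11, 14, 15]  (these layers stay INT8; the rest go to INT4)

In the commit itself, this selection runs per projection group (qkv_proj, v_proj, down_proj), and get_precision_info returns the resulting {weight_name: 8 or 4} mapping only when enable_mixed_quant is set.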
