Commit 5128fe7

ynankani authored and yeyu-nvidia committed
[5506930] Add support in ModelOpt for generating mixed-precision (INT4… (#310)
Signed-off-by: unknown <[email protected]> Signed-off-by: Ye Yu <[email protected]>
1 parent 55c5b4b commit 5128fe7

File tree

6 files changed: +653 -212 lines changed


examples/windows/onnx_ptq/genai_llm/README.md

Lines changed: 2 additions & 0 deletions
@@ -56,6 +56,8 @@ The table below lists key command-line arguments of the ONNX PTQ example script.
 | `--awqclip_bsz_col` | 1024 (default) | Chunk size in columns during weight clipping, user-defined |
 | `--calibration_eps` | dml, cuda, cpu, NvTensorRtRtx (default: [dml,cpu]) | List of execution providers to use for the session run during calibration |
 | `--no_position_ids` | Default: position_ids input enabled | Use this option to disable the position_ids input in calibration data |
+| `--enable_mixed_quant` | Default: mixed quant disabled | Use this option to enable mixed-precision quantization |
+| `--layers_8bit` | Default: None | Use this option to override the default mixed-quant strategy |

 Run the following command to view all available parameters in the script:


examples/windows/onnx_ptq/genai_llm/quantize.py

Lines changed: 22 additions & 2 deletions
@@ -361,11 +361,15 @@ def main(args):
     # device = torch.device(f"cuda:{device_id}")
     device = torch.device(args.device)

+    if args.layers_8bit:
+        args.enable_mixed_quant = True
+
     print(
         f"\n--Quantize-Script-- algo={args.algo}, dataset={args.dataset}, calib_size={args.calib_size}, "
         f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, "
         f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, "
-        f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32}\n"
+        f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32}, enable_mixed_quant={args.enable_mixed_quant}, "
+        f"layers_8bit={args.layers_8bit}\n"
     )

     print(
@@ -435,6 +439,8 @@ def main(args):
         awqclip_alpha_step=args.awqclip_alpha_step,
         awqclip_alpha_min=args.awqclip_alpha_min,
         awqclip_bsz_col=args.awqclip_bsz_col,
+        enable_mixed_quant=args.enable_mixed_quant,
+        layers_8bit=args.layers_8bit,
     )
     logging.info(f"\nQuantization process took {time.time() - t} seconds")

@@ -594,6 +600,20 @@ def main(args):
         default=False,
         action="store_true",
     )
-
+    parser.add_argument(
+        "--enable_mixed_quant",
+        default=False,
+        action="store_true",
+        help=(
+            "Use the default mixed-quantization strategy: the first 1/8, the last 1/8, and every 3rd of the "
+            "remaining attn/mlp layers are quantized to 8 bits; all other layers to 4 bits."
+        ),
+    )
+    parser.add_argument(
+        "--layers_8bit",
+        type=str,
+        default="",
+        help="Overrides the default mixed-quant strategy. Example: 'layers.0,lm_head'",
+    )
     args = parser.parse_args()
     main(args)
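
As an aside (not part of the diff): a minimal, runnable sketch of how the two new flags interact once parsed. The stand-alone parser below only mirrors the options added above; passing `--layers_8bit` implies `--enable_mixed_quant`, as in the check added at the top of `main()`.

import argparse

# Illustrative stand-alone parser mirroring the two options added in this commit.
parser = argparse.ArgumentParser(description="mixed-precision flag sketch")
parser.add_argument("--enable_mixed_quant", default=False, action="store_true")
parser.add_argument("--layers_8bit", type=str, default="")

# Simulate: python quantize.py --layers_8bit "layers.0,lm_head"
args = parser.parse_args(["--layers_8bit", "layers.0,lm_head"])
if args.layers_8bit:
    # An explicit 8-bit layer list switches mixed quantization on.
    args.enable_mixed_quant = True

print(args.enable_mixed_quant)      # True
print(args.layers_8bit.split(","))  # ['layers.0', 'lm_head']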

modelopt/onnx/quantization/graph_utils.py

Lines changed: 216 additions & 0 deletions
@@ -18,6 +18,7 @@
 import re
 from collections import defaultdict
 from functools import reduce
+from typing import Any, cast

 import numpy as np
 import onnx
@@ -625,6 +626,221 @@ def _find_nodes_from_op_types_to_exclude(graph: Graph, op_types_to_exclude=None)
     return nodes_to_exclude


+def _find_int4_quantizable_weights(
+    graph: onnx.GraphProto,
+    nodes_to_exclude: list[str],
+) -> list[tuple[onnx.ValueInfoProto, onnx.ValueInfoProto, bool, int]]:
+    """Finds the int4 quantizable weights from the graph."""
+    wa_pack = []
+    gemm_nodes = [
+        node
+        for node in graph.node
+        if node.op_type in ["Gemm", "MatMul"] and node.name not in nodes_to_exclude
+    ]
+    initializer_idxs = {initializer.name: idx for idx, initializer in enumerate(graph.initializer)}
+    for gemm in gemm_nodes:
+        if gemm.input[0] in initializer_idxs:
+            # Ex. two const inputs to MatMul_115 in fastvit0.onnx
+            # Note: the RTN algorithm will quantize these weights though
+            continue
+
+        if gemm.input[1] not in initializer_idxs:
+            continue
+
+        weight_tensor = graph.initializer[initializer_idxs[gemm.input[1]]]
+        if len(weight_tensor.dims) == 1:  # 1D blocked quantization not supported
+            continue
+
+        gemm_io_type = cast("int", weight_tensor.data_type)
+
+        act_tensor = onnx.helper.ValueInfoProto()
+        act_tensor.name = gemm.input[0]
+
+        # TODO: support transA by transposing activation tensors in _clip_search
+        do_transpose = gemm.op_type == "Gemm" and any(
+            attr.name == "transB" and attr.i > 0 for attr in gemm.attribute
+        )
+
+        wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type))
+
+    return wa_pack
+
+
+def should_quantize_to_8bit(layer_name: str, layers_8bit: list[str]):
+    """Check whether a layer should be quantized to 8 bits.
+
+    The layers_8bit list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
+    The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'.
+
+    To match these, we:
+    - Remove the leading slash from the node name.
+    - Replace all '/' with '.' to match the naming convention of the initializer.
+
+    This allows us to correctly identify which weights should be quantized to 8 bits.
+    """
+    if not layers_8bit:
+        return False
+
+    # Normalize both to dot-delimited tokens and require an exact token-sequence match.
+    def tokens(s: str) -> list[str]:
+        return s.lstrip("/").replace("/", ".").split(".")
+
+    hay = tokens(layer_name)
+    for pat in layers_8bit:
+        needle = tokens(pat)
+        n, m = len(hay), len(needle)
+        for i in range(n - m + 1):
+            if hay[i : i + m] == needle:
+                return True
+    return False
+
+
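
As an aside (not part of the diff): a minimal sketch of the name normalization described in the docstring above. The weight and pattern strings are hypothetical examples of the two naming conventions being matched.

# Hypothetical names, chosen to mirror the docstring above.
weight_name = "model.layers.13.attn.qkv_proj.MatMul.weight"
node_pattern = "/model/layers.13/attn/qkv_proj/MatMul"

def tokens(s: str) -> list[str]:
    # Same normalization as should_quantize_to_8bit: strip the leading '/', use '.' separators.
    return s.lstrip("/").replace("/", ".").split(".")

hay, needle = tokens(weight_name), tokens(node_pattern)
# The pattern's token sequence appears contiguously inside the weight-name tokens,
# so this weight would be promoted to 8 bits.
match = any(hay[i : i + len(needle)] == needle for i in range(len(hay) - len(needle) + 1))
print(match)  # True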
+def validate_8bit_layers(layers_str: str) -> bool:
+    """Validate the format of the layers_8bit string."""
+    if not layers_str:
+        return True
+    # Allow comma-separated list of path-like tokens
+    pattern = r"^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$"
+    return bool(re.match(pattern, layers_str))
+
+
+def get_layer_precision_mapping(
+    onnx_model: onnx.ModelProto,
+    precision_pattern_8bit: str | None = None,
+    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+):
+    """Generate a mapping of layer names to their quantization precision (4 bits or 8 bits) for an ONNX model.
+
+    Args:
+        onnx_model (onnx.ModelProto): The ONNX model to analyze.
+        precision_pattern_8bit (str, optional): Comma-separated string of layer patterns to quantize to 8 bits.
+            If None, a default set of patterns is used to select layers for 8-bit quantization.
+        nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization.
+            Defaults to [r"/lm_head"].
+
+    Returns:
+        dict: A mapping from layer names to their quantization precision (e.g., {"layer_name": 8}).
+    """
+    graph = onnx_model.graph
+
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    # Collect quantizable weight tensors
+    wa_pack = _find_int4_quantizable_weights(graph, nodes_to_exclude)
+
+    if precision_pattern_8bit:
+        if not validate_8bit_layers(precision_pattern_8bit):
+            raise ValueError("Invalid format for --layers_8bit. Use comma-separated layers.")
+        layers_list_8bit = [x.strip() for x in precision_pattern_8bit.split(",") if x.strip()]
+
+    else:
+        matmul_nodes = [
+            node
+            for node in onnx_model.graph.node
+            if node.op_type == "MatMul" and "lm_head" not in node.name
+        ]
+
+        # Only include nodes matching the specified patterns for all layers present in the model.
+        # For example, for all i where a node exists with name:
+        #     /model/layers.{i}/attn/qkv_proj/MatMul
+        #     /model/layers.{i}/attn/v_proj/MatMul
+        #     /model/layers.{i}/mlp/down_proj/MatMul
+        pattern_regexes = [
+            re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"),
+            re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"),
+            re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"),
+        ]
+
+        # Filter matmul_nodes to only those matching the patterns
+        filtered_matmul_nodes = []
+        for node in matmul_nodes:
+            for pat in pattern_regexes:
+                if pat.match(node.name):
+                    filtered_matmul_nodes.append(node)
+                    break
+
+        # Build a mapping from group key to list of node names (ordered by layer index if possible)
+        def extract_group_key(node_name):
+            # Extract the two components before 'MatMul' in the name, e.g. ...attn.qkv_proj
+            parts = node_name.split("/")
+            if len(parts) >= 3:
+                return ".".join(parts[-3:-1])
+            return node_name
+
+        group_to_nodes = {}
+        for node in filtered_matmul_nodes:
+            group_key = extract_group_key(node.name)
+            group_to_nodes.setdefault(group_key, []).append(node.name)
+
+        layers_8bit_set = set()
+        for names in group_to_nodes.values():
+            n = len(names)
+            if n == 0:
+                continue
+
+            # Try to sort by layer index if present
+            def layer_idx(name):
+                m = re.search(r"layers\.(\d+)", name)
+                return int(m.group(1)) if m else 0
+
+            names_sorted = sorted(names, key=layer_idx)
+            first_eighth = int(n // 8)
+            last_eighth = int(n // 8)
+            # First 1/8
+            layers_8bit_set.update(names_sorted[:first_eighth])
+            # Last 1/8
+            if last_eighth > 0:
+                layers_8bit_set.update(names_sorted[-last_eighth:])
+            # Every third in the rest (excluding the first and last eighth)
+            rest_start = first_eighth
+            rest_end = n - last_eighth
+            for i in range(rest_start, rest_end):
+                if (i - rest_start) % 3 == 0:
+                    layers_8bit_set.add(names_sorted[i])
+        layers_list_8bit = list(layers_8bit_set)
+
+    # Create the precision info mapping
+    precision_info = {}
+    for _act_tensor, weight_tensor, _do_transpose, _gemm_io_type in wa_pack:
+        weight_name = weight_tensor.name
+        if should_quantize_to_8bit(weight_name, layers_list_8bit):
+            precision_info[weight_name] = 8
+        else:
+            precision_info[weight_name] = 4
+    return precision_info
+
+
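
As an aside (not part of the diff): a small worked example of the default 8-bit selection above, assuming a hypothetical 32-layer model with a single qkv_proj group.

# Hypothetical 32-layer model with one projection group (qkv_proj MatMuls only).
names = [f"/model/layers.{i}/attn/qkv_proj/MatMul" for i in range(32)]

n = len(names)
first_eighth = last_eighth = n // 8               # 4 layers each for n = 32
selected = set(names[:first_eighth])              # first 1/8  -> layers 0-3
selected.update(names[n - last_eighth:])          # last 1/8   -> layers 28-31
for i in range(first_eighth, n - last_eighth):
    if (i - first_eighth) % 3 == 0:               # every 3rd layer in between
        selected.add(names[i])

# 16 of the 32 qkv_proj weights end up at 8 bits; the remaining 16 stay at 4 bits.
print(len(selected))  # 16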
+def get_precision_info(
+    onnx_model: onnx.ModelProto,
+    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    **kwargs: Any,
+):
+    """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).
+
+    This function determines the quantization precision for each weight tensor in the ONNX model,
+    based on the provided configuration. If mixed quantization is enabled, it uses the layer
+    precision mapping; otherwise, it returns None.
+
+    Args:
+        onnx_model (onnx.ModelProto): The ONNX model to analyze.
+        nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
+        **kwargs: Additional keyword arguments, such as:
+            - enable_mixed_quant (bool): Whether to enable mixed quantization.
+            - layers_8bit (str): Comma-separated list of layer patterns to quantize to 8 bits.
+
+    Returns:
+        dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
+        or None if mixed quantization is not enabled.
+    """
+    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
+    layers_8bit = kwargs.get("layers_8bit")
+    if not enable_mixed_quant:
+        return None
+    return get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude)
+
+
 def expand_node_names_from_patterns(
     graph: onnx.GraphProto | Graph, name_patterns: list[str] | None = None
 ) -> list[str]:
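
As an aside (not part of the diff): a minimal usage sketch of the new entry point. The model path is a placeholder, and it assumes `get_precision_info` is importable from the module shown above.

import onnx

from modelopt.onnx.quantization.graph_utils import get_precision_info

model = onnx.load("model.onnx")  # placeholder path

# Default strategy: first/last 1/8 plus every 3rd of the remaining tracked layers at 8 bits.
precision_info = get_precision_info(model, enable_mixed_quant=True)

# Explicit override: only layers matching the listed patterns are promoted to 8 bits.
precision_info = get_precision_info(
    model, enable_mixed_quant=True, layers_8bit="layers.0,lm_head"
)

# Result is {weight_name: 4 or 8}, or None when mixed quantization is disabled.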
