2 changes: 2 additions & 0 deletions examples/windows/onnx_ptq/genai_llm/README.md
@@ -56,6 +56,8 @@ The table below lists key command-line arguments of the ONNX PTQ example script.
| `--awqclip_bsz_col` | 1024 (default) | Chunk size in columns during weight clipping, user-defined |
| `--calibration_eps` | dml, cuda, cpu, NvTensorRtRtx (default: [dml,cpu]) | List of execution-providers to use for session run during calibration |
| `--no_position_ids` | Default: position_ids input enabled | Use this option to disable position_ids input in calibration data|
| `--enable_mixed_quant` | Default: mixed quant disabled | Use this option to enable mixed-precision quantization |
| `--layers_8bit` | Default: None | Comma-separated list of layer patterns to quantize to 8 bits; overrides the default mixed quant strategy |
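
For example, passing `--enable_mixed_quant` alone applies the default mixed-precision strategy (first 1/8 and last 1/8 of layers, plus every third attn/mlp layer in between, at 8 bits), while adding something like `--layers_8bit "/model/layers.0/attn/qkv_proj/MatMul,/model/layers.1/mlp/down_proj/MatMul"` (illustrative node names) restricts 8-bit quantization to the listed layers and keeps the remaining weights at 4 bits.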

Run the following command to view all available parameters in the script:

24 changes: 22 additions & 2 deletions examples/windows/onnx_ptq/genai_llm/quantize.py
@@ -361,11 +361,15 @@ def main(args):
# device = torch.device(f"cuda:{device_id}")
device = torch.device(args.device)

if args.layers_8bit:
args.enable_mixed_quant = True

print(
f"\n--Quantize-Script-- algo={args.algo}, dataset={args.dataset}, calib_size={args.calib_size}, "
f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, "
f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, "
f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32}\n"
f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} enable_mixed_quant={args.enable_mixed_quant}, "
f"layers_8bit={args.layers_8bit}\n"
)

print(
@@ -435,6 +439,8 @@ def main(args):
awqclip_alpha_step=args.awqclip_alpha_step,
awqclip_alpha_min=args.awqclip_alpha_min,
awqclip_bsz_col=args.awqclip_bsz_col,
enable_mixed_quant=args.enable_mixed_quant,
layers_8bit=args.layers_8bit,
)
logging.info(f"\nQuantization process took {time.time() - t} seconds")

@@ -594,6 +600,20 @@ def main(args):
default=False,
action="store_true",
)

parser.add_argument(
"--enable_mixed_quant",
default=False,
action="store_true",
help=(
"Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd attn, "
"mlp layers quantized to 8 bits; others to 4 bits."
),
)
parser.add_argument(
"--layers_8bit",
type=str,
default="",
help=("Overrides default mixed quant strategy. Example: 'layers.0,lm_head'"),
)
args = parser.parse_args()
main(args)
216 changes: 216 additions & 0 deletions modelopt/onnx/quantization/graph_utils.py
@@ -18,6 +18,7 @@
import re
from collections import defaultdict
from functools import reduce
from typing import Any, cast

import numpy as np
import onnx
@@ -625,6 +626,221 @@ def _find_nodes_from_op_types_to_exclude(graph: Graph, op_types_to_exclude=None)
return nodes_to_exclude


def _find_int4_quantizable_weights(
graph: onnx.GraphProto,
nodes_to_exclude: list[str],
) -> list[tuple[onnx.ValueInfoProto, onnx.ValueInfoProto, bool, int]]:
"""Finds the int4 quantizable weights from the graph."""
wa_pack = []
gemm_nodes = [
node
for node in graph.node
if node.op_type in ["Gemm", "MatMul"] and node.name not in nodes_to_exclude
]
initializer_idxs = {initializer.name: idx for idx, initializer in enumerate(graph.initializer)}
for gemm in gemm_nodes:
if gemm.input[0] in initializer_idxs:
# Ex. two const input to MatMul_115 in fastvit0.onnx
# Note. RTN algorithm will quantize these weights though
continue

if gemm.input[1] not in initializer_idxs:
continue

weight_tensor = graph.initializer[initializer_idxs[gemm.input[1]]]
if len(weight_tensor.dims) == 1: # 1D blocked quantization not supported
continue

gemm_io_type = cast("int", weight_tensor.data_type)

act_tensor = onnx.helper.ValueInfoProto()
act_tensor.name = gemm.input[0]

# TODO: support transA by transposing activation tensors in _clip_search
do_transpose = gemm.op_type == "Gemm" and any(
attr.name == "transB" and attr.i > 0 for attr in gemm.attribute
)

wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type))

return wa_pack


def should_quantize_to_8bit(layer_name: str, layers_8bit: list[str]):
"""Check if layer should be quantized to 8 bits.

The layers_8bit list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'.

To match these, we:
- Remove the leading slash from the node name.
- Replace all '/' with '.' to match the naming convention of the initializer.

This allows us to correctly identify which weights should be quantized to 8 bits.
"""
if not layers_8bit:
return False

# Normalize both to dot-delimited tokens and require exact token sequence match.
def tokens(s: str) -> list[str]:
return s.lstrip("/").replace("/", ".").split(".")

hay = tokens(layer_name)
for pat in layers_8bit:
needle = tokens(pat)
n, m = len(hay), len(needle)
for i in range(n - m + 1):
if hay[i : i + m] == needle:
return True
return False
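
As a hedged illustration of the matching rule documented above (the layer index and node names are made up for this example, not taken from the PR):

```python
from modelopt.onnx.quantization.graph_utils import should_quantize_to_8bit

# "/model/layers.13/attn/qkv_proj/MatMul" normalizes to the token run
# ["model", "layers", "13", "attn", "qkv_proj", "MatMul"], which appears verbatim
# inside the tokens of the initializer name, so this weight is selected for 8 bits.
assert should_quantize_to_8bit(
    "model.layers.13.attn.qkv_proj.MatMul.weight",
    ["/model/layers.13/attn/qkv_proj/MatMul"],
)

# Exact token matching means "layers.1" does not accidentally match "layers.13".
assert not should_quantize_to_8bit(
    "model.layers.13.attn.qkv_proj.MatMul.weight",
    ["/model/layers.1/attn/qkv_proj/MatMul"],
)
```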


def validate_8bit_layers(layers_str: str) -> bool:
"""Validate the format of layers_8bit string."""
if not layers_str:
return True
# Allow comma-separated list of path-like tokens
pattern = r"^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$"
return bool(re.match(pattern, layers_str))
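
In practice this accepts comma-separated, path-like entries such as `layers.0,lm_head` or `/model/layers.13/attn/qkv_proj/MatMul` (slashes, dots, hyphens, underscores, and alphanumerics) and rejects anything else, e.g. a semicolon-separated list.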


def get_layer_precision_mapping(
onnx_model: onnx.ModelProto,
precision_pattern_8bit: str | None = None,
nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
"""Generate a mapping of layer names to their quantization precision (4 bits or 8 bits) for an ONNX model.

Args:
onnx_model (onnx.ModelProto): The ONNX model to analyze.
precision_pattern_8bit (str, optional): Comma-separated string of layer patterns to quantize to 8 bits.
If None, a default set of patterns is used to select layers for 8 bits quantization.
nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization.
Defaults to [r"/lm_head"].

Returns:
        dict: A mapping from weight tensor names to their quantization precision in bits (e.g., {"layer_name": 8}).
"""
graph = onnx_model.graph

nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
# Collect quantizable weight tensors
wa_pack = _find_int4_quantizable_weights(graph, nodes_to_exclude)

if precision_pattern_8bit:
if not validate_8bit_layers(precision_pattern_8bit):
raise ValueError("Invalid format for --layers_8bit. Use comma-separated layers.")
layers_list_8bit = [x.strip() for x in precision_pattern_8bit.split(",") if x.strip()]

else:
matmul_nodes = [
node
for node in onnx_model.graph.node
if node.op_type == "MatMul" and "lm_head" not in node.name
]

# Only include nodes matching the specified patterns for all layers present in the model
# For example, for all i where a node exists with name:
# /model/layers.{i}/attn/qkv_proj/MatMul
# /model/layers.{i}/attn/v_proj/MatMul
# /model/layers.{i}/mlp/down_proj/MatMul
pattern_regexes = [
re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"),
re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"),
re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"),
]

# Filter matmul_nodes to only those matching the patterns
filtered_matmul_nodes = []
for node in matmul_nodes:
for pat in pattern_regexes:
if pat.match(node.name):
filtered_matmul_nodes.append(node)
break

# Build a mapping from group key to list of node names (ordered by layer index if possible)
def extract_group_key(node_name):
# Extract the two components before 'MatMul' in the name, e.g. ...foo.bar.MatMul
parts = node_name.split("/")
if len(parts) >= 3:
return ".".join(parts[-3:-1])
return node_name

group_to_nodes = {}
for node in filtered_matmul_nodes:
group_key = extract_group_key(node.name)
group_to_nodes.setdefault(group_key, []).append(node.name)

layers_8bit_set = set()
for names in group_to_nodes.values():
n = len(names)
if n == 0:
continue

# Try to sort by layer index if present
def layer_idx(name):
m = re.search(r"layers\.(\d+)\.", name)
return int(m.group(1)) if m else 0

names_sorted = sorted(names, key=layer_idx)
first_eighth = int(n // 8)
last_eighth = int(n // 8)
# First 1/8
layers_8bit_set.update(names_sorted[:first_eighth])
# Last 1/8
if last_eighth > 0:
layers_8bit_set.update(names_sorted[-last_eighth:])
# Every third in the rest (excluding first and last eighth)
rest_start = first_eighth
rest_end = n - last_eighth
for i in range(rest_start, rest_end):
if (i - rest_start) % 3 == 0:
layers_8bit_set.add(names_sorted[i])
layers_list_8bit = list(layers_8bit_set)

# NEW: Create precision info mapping
precision_info = {}
for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack):
weight_name = weight_tensor.name
if should_quantize_to_8bit(weight_name, layers_list_8bit):
precision_info[weight_name] = 8
else:
precision_info[weight_name] = 4
return precision_info
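
To make the default selection rule concrete, here is a small standalone sketch for a hypothetical 16-layer model; the real function applies this per MatMul group (qkv_proj, v_proj, down_proj) and then maps the matched weight initializers to 8 bits:

```python
# Hypothetical, already index-sorted group of per-layer MatMul node names.
names_sorted = [f"/model/layers.{i}/attn/qkv_proj/MatMul" for i in range(16)]
n = len(names_sorted)

eighth = n // 8                          # 2 for a 16-layer model
selected = set(names_sorted[:eighth])    # first 1/8 -> layers 0, 1
selected.update(names_sorted[-eighth:])  # last 1/8  -> layers 14, 15
for i in range(eighth, n - eighth):      # every third layer in between
    if (i - eighth) % 3 == 0:
        selected.add(names_sorted[i])    # layers 2, 5, 8, 11

print(sorted(selected))  # 8 of the 16 qkv_proj MatMuls end up at 8 bits
```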


def get_precision_info(
onnx_model: onnx.ModelProto,
nodes_to_exclude: list[str] | None = [r"/lm_head"],
**kwargs: Any,
):
"""Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).

This function determines the quantization precision for each weight tensor in the ONNX model,
based on the provided configuration. If mixed quantization is enabled, it uses the layer
precision mapping; otherwise, it returns None.

Args:
onnx_model (onnx.ModelProto): The ONNX model to analyze.
nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
**kwargs: Additional keyword arguments, such as:
- enable_mixed_quant (bool): Whether to enable mixed quantization.
- layers_8bit (str): Comma-separated list of layer patterns to quantize to 8 bit.

Returns:
dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
or None if mixed quantization is not enabled.
"""
precision_info = None
enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
layers_8bit = kwargs.get("layers_8bit")
if enable_mixed_quant:
precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude)
else:
precision_info = None
return precision_info
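
A brief usage sketch of the helper above (the model path is a placeholder and the `layers_8bit` pattern is illustrative):

```python
import onnx

from modelopt.onnx.quantization.graph_utils import get_precision_info

onnx_model = onnx.load("model.onnx")  # placeholder path

# Mixed quant disabled (default): no per-weight mapping is produced.
assert get_precision_info(onnx_model) is None

# Mixed quant enabled: weight initializer names map to 4 or 8 (bits); layers_8bit
# optionally overrides the default first/last-eighth + every-third strategy.
precision_info = get_precision_info(
    onnx_model,
    enable_mixed_quant=True,
    layers_8bit="/model/layers.0/attn/qkv_proj/MatMul",
)
```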


def expand_node_names_from_patterns(
graph: onnx.GraphProto | Graph, name_patterns: list[str] | None = None
) -> list[str]: