1 change: 1 addition & 0 deletions examples/windows/onnx_ptq/genai_llm/README.md
@@ -56,6 +56,7 @@ The table below lists key command-line arguments of the ONNX PTQ example script.
| `--awqclip_bsz_col` | 1024 (default) | Chunk size in columns during weight clipping, user-defined |
| `--calibration_eps` | dml, cuda, cpu, NvTensorRtRtx (default: [dml,cpu]) | List of execution-providers to use for session run during calibration |
| `--no_position_ids` | Default: position_ids input enabled | Use this option to disable position_ids input in calibration data|
| `--enable_mixed_quant` | Default: mixed quant disabled | Use this option to enable mixed precision quantization|

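For example, a hypothetical invocation combining the two new options (other required arguments elided, mirroring the usage discussed in the review conversation below):

```bash
python quantize.py ... --enable_mixed_quant --int8_layers="layers.0,lm_head"
```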
Run the following command to view all available parameters in the script:

20 changes: 18 additions & 2 deletions examples/windows/onnx_ptq/genai_llm/quantize.py
@@ -365,7 +365,7 @@ def main(args):
f"\n--Quantize-Script-- algo={args.algo}, dataset={args.dataset}, calib_size={args.calib_size}, "
f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, "
f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, "
f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32}\n"
f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32}, enable_mixed_quant={args.enable_mixed_quant}\n"
)

print(
@@ -435,6 +435,8 @@ def main(args):
awqclip_alpha_step=args.awqclip_alpha_step,
awqclip_alpha_min=args.awqclip_alpha_min,
awqclip_bsz_col=args.awqclip_bsz_col,
enable_mixed_quant=args.enable_mixed_quant,
int8_layers=args.int8_layers,
)
logging.info(f"\nQuantization process took {time.time() - t} seconds")

@@ -594,6 +596,20 @@ def main(args):
default=False,
action="store_true",
)

parser.add_argument(
"--enable_mixed_quant",
default=False,
action="store_true",
help="True when we want to use mixed quantization",
@i-riyad (Contributor), Sep 19, 2025: nit: help="Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd layer quantized to INT8; others to INT4.",
)
parser.add_argument(
"--int8_layers",
Comment on lines +600 to +606

Contributor: Can we just assume mixed_quant is enabled if --int8_layers is non-empty, and remove the --enable_mixed_quant option?

@ynankani (Author), Sep 18, 2025: If --int8_layers is specified, only the layers matching the patterns in int8_layers are quantized to INT8; if only --enable_mixed_quant is specified, we hard-code a few important layers, similar to what other quantization tools like model builder and llama.cpp do. For example:
`python quantize.py ... --int8_layers="layer.0" --enable_mixed_quant` => all layer.0 matches are quantized to INT8
`python quantize.py ... --enable_mixed_quant` => quantize to INT8 the first 1/8, the last 1/8, and every 3rd layer of the attn and ffn layers below:
/model/layers.{i}/attn/qkv_proj/MatMul
/model/layers.{i}/attn/v_proj/MatMul
/model/layers.{i}/mlp/down_proj/MatMul

Contributor: Gotcha, makes sense now.

type=str,
default="",
help=(
"Comma-separated list of layer patterns to quantize to INT8 instead of INT4. "
"Example: 'layers.0,layers.1,lm_head'"
Contributor: nit: Overrides default mixed quant strategy. Example: 'layers.0,lm_head'

),
)
args = parser.parse_args()
main(args)
218 changes: 218 additions & 0 deletions modelopt/onnx/quantization/graph_utils.py
@@ -18,6 +18,7 @@
import re
from collections import defaultdict
from functools import reduce
from typing import Any, cast

import numpy as np
import onnx
@@ -625,6 +626,223 @@ def _find_nodes_from_op_types_to_exclude(graph: Graph, op_types_to_exclude=None)
return nodes_to_exclude


def _find_quantizable_weights(
graph: onnx.GraphProto,
nodes_to_exclude: list[str],
) -> list[tuple[onnx.ValueInfoProto, onnx.TensorProto, bool, int]]:
"""Finds the quantizable weights from the graph."""
wa_pack = []
gemm_nodes = [
node
for node in graph.node
if node.op_type in ["Gemm", "MatMul"] and node.name not in nodes_to_exclude
]
initializer_idxs = {initializer.name: idx for idx, initializer in enumerate(graph.initializer)}
for gemm in gemm_nodes:
if gemm.input[0] in initializer_idxs:
# Ex. two const input to MatMul_115 in fastvit0.onnx
# Note. RTN algorithm will quantize these weights though
continue

if gemm.input[1] not in initializer_idxs:
continue

weight_tensor = graph.initializer[initializer_idxs[gemm.input[1]]]
if len(weight_tensor.dims) == 1: # 1D blocked quantization not supported
continue

gemm_io_type = cast("int", weight_tensor.data_type)

act_tensor = onnx.helper.ValueInfoProto()
act_tensor.name = gemm.input[0]

# TODO: support transA by transposing activation tensors in _clip_search
do_transpose = gemm.op_type == "Gemm" and any(
attr.name == "transB" and attr.i > 0 for attr in gemm.attribute
)

wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type))

return wa_pack
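A minimal sketch of consuming the returned pack (the names in the comments are illustrative, not from this PR):

```python
# Each wa_pack entry pairs a Gemm/MatMul activation with its weight initializer.
for act_tensor, weight_tensor, do_transpose, gemm_io_type in wa_pack:
    # act_tensor    -> onnx.ValueInfoProto; only .name is populated here
    # weight_tensor -> onnx.TensorProto holding the 2-D weight
    # do_transpose  -> True only for Gemm nodes with transB > 0
    # gemm_io_type  -> the weight's TensorProto data_type enum value
    print(weight_tensor.name, list(weight_tensor.dims))
```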


def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool:
"""Check if layer should be quantized to INT8.

The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'.

To match these, we:
- Remove the leading slash from the node name.
- Replace all '/' with '.' to match the naming convention of the initializer.

This allows us to correctly identify which weights should be quantized to INT8.
"""
if not int8_layers:
return False

# Normalize both to dot-delimited tokens and require exact token sequence match.
def tokens(s: str) -> list[str]:
return s.lstrip("/").replace("/", ".").split(".")

hay = tokens(layer_name)
for pat in int8_layers:
needle = tokens(pat)
n, m = len(hay), len(needle)
for i in range(n - m + 1):
if hay[i : i + m] == needle:
return True
return False
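A quick illustration of the token-sequence matching (hypothetical names):

```python
# A node-style pattern matches the dot-delimited initializer name token-for-token
should_quantize_to_int8(
    "model.layers.13.attn.qkv_proj.MatMul.weight",
    ["/model/layers.13/attn/qkv_proj/MatMul"],
)  # True

# Exact token match: "layers.13" does not match "layers.130"
should_quantize_to_int8("model.layers.130.mlp.down_proj.MatMul.weight", ["layers.13"])  # False
```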


def validate_int8_layers(layers_str: str) -> bool:
"""Validate the format of int8_layers string."""
if not layers_str:
return True
# Basic validation: check for valid characters and structure
import re

pattern = r"^[a-zA-Z0-9_.,\-]$"
return bool(re.match(pattern, layers_str))

Comment on lines +698 to +707

@coderabbitai (bot), Sep 12, 2025:
⚠️ Potential issue

Fix regex validation to accept real layer patterns.

The current pattern r"^[a-zA-Z0-9_.,\-]$" only matches a single character and disallows /. This will reject valid inputs like /model/layers.13/attn/qkv_proj/MatMul.

Apply:

-def validate_int8_layers(layers_str: str) -> bool:
+def validate_int8_layers(layers_str: str) -> bool:
@@
-    pattern = r"^[a-zA-Z0-9_.,\-]$"
-    return bool(re.match(pattern, layers_str))
+    # Allow comma-separated list of path-like tokens
+    pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
+    return bool(re.match(pattern, layers_str))

@i-riyad (Contributor), Sep 19, 2025: Look into this suggestion

@coderabbitai (bot):

@i-riyad I've analyzed the suggestion and it's correct - there's a critical bug in the regex validation.

The CodeRabbit suggestion is absolutely correct. The current regex has two major flaws:

  1. Single character matching: The pattern ^[a-zA-Z0-9_.,\-]$ only matches exactly one character due to the absence of quantifiers
  2. Missing forward slash: Real ONNX node names like /model/layers.13/attn/qkv_proj/MatMul contain forward slashes which aren't allowed

The suggested fix properly addresses both issues by:

  • Adding + quantifier to match one or more characters
  • Including / in the character class
  • Supporting comma-separated values with optional whitespace

However, I'd recommend a small improvement to make the validation more robust:

def validate_int8_layers(layers_str: str) -> bool:
    """Validate the format of int8_layers string."""
    if not layers_str:
        return True
    # Allow comma-separated list of path-like tokens
    # Pattern allows: letters, digits, underscore, dot, slash, hyphen
    # Supports comma-separated values with optional whitespace
    pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
    return bool(re.fullmatch(pattern, layers_str))

Using re.fullmatch() instead of re.match() provides clearer intent (matching the entire string) and is slightly more explicit, though functionally equivalent when using ^ and $ anchors.

The suggestion should definitely be implemented to fix this validation bug.


🏁 Script executed:

#!/bin/bash
# Let's examine the current validate_int8_layers function and test the regex behavior
cd /tmp
cat > test_regex.py << 'EOF'
import re

# Current implementation
def validate_int8_layers_current(layers_str: str) -> bool:
    if not layers_str:
        return True
    pattern = r"^[a-zA-Z0-9_.,\-]$"
    return bool(re.match(pattern, layers_str))

# Suggested implementation
def validate_int8_layers_suggested(layers_str: str) -> bool:
    if not layers_str:
        return True
    pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
    return bool(re.match(pattern, layers_str))

# Test cases
test_cases = [
    "/model/layers.13/attn/qkv_proj/MatMul",
    "/model/layers.0/attn/v_proj/MatMul,/model/layers.1/mlp/down_proj/MatMul",
    "model.layers.13.attn.qkv_proj.MatMul.weight",
    "a",  # single character - should pass for current
    "",   # empty string - should pass for both
    "/invalid/path with spaces/MatMul",  # spaces in path - should fail
    "valid,path",  # simple comma-separated - should pass for suggested
]

print("Testing current vs suggested implementation:")
print("=" * 60)
for test in test_cases:
    current_result = validate_int8_layers_current(test)
    suggested_result = validate_int8_layers_suggested(test)
    print(f"Input: '{test}'")
    print(f"  Current:   {current_result}")
    print(f"  Suggested: {suggested_result}")
    print()
EOF

python test_regex.py



Perfect! The test results confirm the analysis. The current implementation is severely broken - it only accepts single characters or empty strings, making it useless for validating real layer patterns.

Key findings:

  • ✅ Current regex rejects ALL valid layer patterns like /model/layers.13/attn/qkv_proj/MatMul
  • ✅ Suggested fix correctly accepts valid patterns while rejecting invalid ones (like paths with spaces)
  • ✅ Both handle empty strings correctly
  • ✅ Suggested pattern properly supports comma-separated values

Recommendation: Implement the CodeRabbit suggestion immediately. This is a critical bug that makes the validation function completely non-functional for its intended purpose.

Here's the corrected implementation:

def validate_int8_layers(layers_str: str) -> bool:
    """Validate the format of int8_layers string."""
    if not layers_str:
        return True
    # Basic validation: check for valid characters and structure
    import re

    # Allow comma-separated list of path-like tokens
    pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
    return bool(re.match(pattern, layers_str))

The current regex would cause the mixed-precision quantization feature to fail silently or reject all user-provided layer specifications, which would be a significant usability issue.


def get_layer_precision_mapping(
onnx_model: onnx.ModelProto,
int8_precision_pattern: str | None = None,
nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
"""Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model.
Comment on lines +709 to +714

@coderabbitai (bot):
🛠️ Refactor suggestion

Avoid mutable default args and normalize exclusions.

Using a list default (nodes_to_exclude=[r"/lm_head"]) risks cross-call mutation. Also normalize nodes_to_exclude up-front.

Apply:

-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
     onnx_model: onnx.ModelProto,
     int8_precision_pattern: str | None = None,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
 ):
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)

Also applies to: 728-731
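For context, a minimal stand-alone demonstration of the pitfall (not from this PR):

```python
def append_one(xs=[]):  # the default list is created once, at function definition time
    xs.append(1)
    return xs

print(append_one())  # [1]
print(append_one())  # [1, 1]: the same list object is shared across calls
```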



Args:
onnx_model (onnx.ModelProto): The ONNX model to analyze.
int8_precision_pattern (str, optional): Comma-separated string of layer patterns to quantize to INT8.
If None, a default set of patterns is used to select layers for INT8 quantization.
nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization.
Defaults to [r"/lm_head"].

Returns:
dict: A mapping from layer names to their quantization precision (e.g., {"layer_name": "int8"}).
"""
graph = onnx_model.graph

nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
# Collect quantizable weight tensors
wa_pack = _find_quantizable_weights(graph, nodes_to_exclude)

if int8_precision_pattern:
if not validate_int8_layers(int8_precision_pattern):
raise ValueError("Invalid format for --int8_layers. Use comma-separated layers.")
int8_layers_list = [x.strip() for x in int8_precision_pattern.split(",") if x.strip()]

else:
matmul_nodes = [
node
for node in onnx_model.graph.node
if node.op_type == "MatMul" and "lm_head" not in node.name
]

# Only include nodes matching the specified patterns for all layers present in the model
# For example, for all i where a node exists with name:
# /model/layers.{i}/attn/qkv_proj/MatMul
# /model/layers.{i}/attn/v_proj/MatMul
# /model/layers.{i}/mlp/down_proj/MatMul
pattern_regexes = [
re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"),
re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"),
re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"),
]

# Filter matmul_nodes to only those matching the patterns
filtered_matmul_nodes = []
for node in matmul_nodes:
for pat in pattern_regexes:
if pat.match(node.name):
filtered_matmul_nodes.append(node)
break

# Build a mapping from group key to list of node names (ordered by layer index if possible)
def extract_group_key(node_name):
# Extract the two components before 'MatMul' in the name, e.g. ...foo.bar.MatMul
parts = node_name.split("/")
if len(parts) >= 3:
return ".".join(parts[-3:-1])
return node_name

group_to_nodes = {}
for node in filtered_matmul_nodes:
group_key = extract_group_key(node.name)
group_to_nodes.setdefault(group_key, []).append(node.name)

int8_layers_set = set()
for names in group_to_nodes.values():
n = len(names)
if n == 0:
continue

# Try to sort by layer index if present
def layer_idx(name):
m = re.search(r"layers\.(\d+)", name)
return int(m.group(1)) if m else 0

names_sorted = sorted(names, key=layer_idx)
first_eighth = int(n // 8)
last_eighth = int(n // 8)
# First 1/8
int8_layers_set.update(names_sorted[:first_eighth])
# Last 1/8
if last_eighth > 0:
int8_layers_set.update(names_sorted[-last_eighth:])
# Every third in the rest (excluding first and last eighth)
rest_start = first_eighth
rest_end = n - last_eighth
for i in range(rest_start, rest_end):
if (i - rest_start) % 3 == 0:
int8_layers_set.add(names_sorted[i])
int8_layers_list = list(int8_layers_set)

# Create the precision info mapping: INT8 for selected layers, INT4 for the rest
precision_info = {}
for _act_tensor, weight_tensor, _do_transpose, _gemm_io_type in wa_pack:
weight_name = weight_tensor.name
if should_quantize_to_int8(weight_name, int8_layers_list):
precision_info[weight_name] = 8
else:
precision_info[weight_name] = 4
return precision_info
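As a worked illustration of the default policy above (assuming a hypothetical 32-layer model), the selected indices can be reproduced with:

```python
# first 1/8 and last 1/8 of layers, plus every 3rd layer in between -> INT8
n = 32
eighth = n // 8                                # 4
int8_idx = set(range(eighth))                  # layers 0-3
int8_idx |= set(range(n - eighth, n))          # layers 28-31
int8_idx |= set(range(eighth, n - eighth, 3))  # layers 4, 7, ..., 25
print(sorted(int8_idx))
# [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 29, 30, 31]
```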


def get_precision_info(
onnx_model: onnx.ModelProto,
nodes_to_exclude: list[str] | None = [r"/lm_head"],
**kwargs: Any,
):
"""Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).

This function determines the quantization precision for each weight tensor in the ONNX model,
based on the provided configuration. If mixed quantization is enabled, it uses the layer
precision mapping; otherwise, it returns None.

Args:
onnx_model (onnx.ModelProto): The ONNX model to analyze.
nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
**kwargs: Additional keyword arguments, such as:
- enable_mixed_quant (bool): Whether to enable mixed quantization.
- int8_layers (str): Comma-separated list of layer patterns to quantize to INT8.

Returns:
dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
or None if mixed quantization is not enabled.
"""
precision_info = None
enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
int8_layers = kwargs.get("int8_layers")
if enable_mixed_quant:
precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude)
else:
precision_info = None
return precision_info
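A minimal usage sketch (the file name and layer patterns are illustrative):

```python
import onnx

model = onnx.load("model.onnx")
precision_info = get_precision_info(
    model,
    enable_mixed_quant=True,
    int8_layers="layers.0,layers.31",  # optional override of the default policy
)
# e.g. {"model.layers.0.attn.qkv_proj.MatMul.weight": 8, ...}; None when disabled
```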

Comment on lines +814 to +844

@coderabbitai (bot):
🛠️ Refactor suggestion

Same mutable default issue in public API.

Apply same fix as above.

-def get_precision_info(
+def get_precision_info(
     onnx_model: onnx.ModelProto,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
     **kwargs: Any,
 ):
@@
-    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)

Committable suggestion skipped: line range outside the PR's diff.



def expand_node_names_from_patterns(
graph: onnx.GraphProto | Graph, name_patterns: list[str] | None = None
) -> list[str]: