
Conversation


@ynankani ynankani commented Sep 9, 2025

…+INT8) ONNX models

What does this PR do?

Type of change: new feature

Overview: Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models

Usage

python quantize.py --model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --onnx_path=C:\Users\Deepseek_qwen_7b\output\DeepSeek-R1-Distill-Qwen-7B_fp16\model\model.onnx --output_path=C:\Users\Deepseek_qwen_7b\output\Deepseel_r1_distill_qwen_7b_mixed_awq\model.onnx --calib_size=32 --algo=awq_lite --dataset=cnn --calibration_eps=NvTensorRtRtx --no_position_ids --enable_mixed_quant

Testing

Ran sanity/functional testing, MMLU, and perf for the models below:

  1. DeepSeek-R1-Distill-Qwen-7B
  2. Qwen2.5-1.5B-Instruct

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: N/A
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • New Features

    • Mixed-precision quantization (per-layer INT4/INT8) with new CLI flags --enable_mixed_quant and --int8_layers.
    • Explicit INT4/UINT4 and INT8/UINT8 support in quantization/export paths.
  • Improvements

    • Per-channel and per-block quantization handling with correct scale shaping, padding, and zero-point support.
    • Runtime now reports mixed-precision flag status.
  • Documentation

    • Example README updated to document the --enable_mixed_quant option.

@ynankani ynankani requested review from a team as code owners September 9, 2025 19:07

coderabbitai bot commented Sep 9, 2025

Walkthrough

Adds per-weight mixed-precision (INT4/INT8) quantization support, centralizes quant math into shared utilities, extends Q/DQ and QDQ insertion to honor per-weight bit widths, and exposes new CLI options --enable_mixed_quant and --int8_layers, which are forwarded into the quantize flow.

Changes

  • CLI: mixed-precision flags (examples/windows/onnx_ptq/genai_llm/quantize.py): Replaced --k_quant_mixed with --enable_mixed_quant (flag, default False) and added --int8_layers (str). Prints enable_mixed_quant in startup status and passes both options to quantize_int4.
  • INT4 quant core: mixed-precision plumbing (modelopt/onnx/quantization/int4.py): Removes local quant helpers in favor of imports from quant_utils; computes/threads precision_info when mixed quant is enabled; forwards enable_mixed_quant and int8_layers into quant paths (RTN/AWQ/AWQ-lite); propagates precision info to DQ/QDQ insertion points.
  • QDQ utilities: dtype and per-weight overrides (modelopt/onnx/quantization/qdq_utils.py): Adds INT4/UINT4 dtype handling and bit-width maps; adds get_tensor_dtype; extends insert_dq_nodes and insert_qdq_nodes to accept precision_info and select per-weight num_bits/dtypes and zero-point handling.
  • Shared quantization utilities (modelopt/onnx/quantization/quant_utils.py): New block-aware helpers: update_block_size, _pad/_depad, find_scales, rtn, dq_tensor, quant_tensor, get_num_bits, update_scale_map_for_per_channel_nodes; adds CLIP_MIN. Implements per-block/per-channel quantization and optional zero-point support.
  • Graph precision mapping (modelopt/onnx/quantization/graph_utils.py): New helpers to find quantizable weights, validate/parse int8_layers, build a per-layer precision mapping (get_layer_precision_mapping), and get_precision_info to return precision_info when enable_mixed_quant is set.
  • Docs (examples/windows/onnx_ptq/genai_llm/README.md): Appends a description of the --enable_mixed_quant flag to the example README (text note with minor typo preserved).
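
For orientation, here is a minimal argparse sketch of the two new flags summarized above (flag names and defaults follow the summary; the help text and the surrounding wiring in quantize.py are assumptions rather than the actual code):

import argparse

parser = argparse.ArgumentParser(description="Sketch of the mixed-precision CLI flags")
# Boolean switch: when set, a per-weight precision map (INT4 vs INT8) is built and used.
parser.add_argument(
    "--enable_mixed_quant",
    action="store_true",
    default=False,
    help="Enable mixed-precision (INT4+INT8) weight quantization.",
)
# Comma-separated layer patterns that should fall back to INT8 instead of INT4.
parser.add_argument(
    "--int8_layers",
    type=str,
    default="",
    help="Example: 'layers.0,layers.1,lm_head'",
)

args = parser.parse_args(["--enable_mixed_quant", "--int8_layers", "layers.0,lm_head"])
print(args.enable_mixed_quant, args.int8_layers)  # True layers.0,lm_head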

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant CLI as quantize.py (CLI)
  participant INT4 as modelopt.onnx.quantization.int4.quantize
  participant MAP as graph_utils.get_precision_info
  participant PIPE as RTN/AWQ/AWQ-lite
  participant QUTIL as quant_utils
  participant QDQ as qdq_utils

  CLI->>INT4: quantize(..., enable_mixed_quant, int8_layers)
  alt enable_mixed_quant == true
    INT4->>MAP: get_precision_info(model, int8_layers)
    MAP-->>INT4: precision_info (weight -> 4|8)
  else
    INT4-->>INT4: precision_info = None
  end

  INT4->>PIPE: run quant path(..., precision_info)
  PIPE->>QUTIL: find_scales / quant_tensor / rtn(..., num_bits via precision_info)
  QUTIL-->>PIPE: quantized weights, scales, zero-points

  PIPE->>QDQ: insert_dq_nodes / insert_qdq_nodes(..., precision_info)
  QDQ-->>PIPE: nodes with per-weight dtype (INT4/INT8) and zero-point handling

  PIPE-->>INT4: updated ONNX model
  INT4-->>CLI: write quantized model
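To make the diagram concrete, below is a small self-contained sketch of how a precision_info map (weight name -> 4 or 8) might drive per-weight bit selection. The helper mirrors the behavior described for get_num_bits in quant_utils (default of 4 bits for unmapped weights), but it is written from scratch here and the weight names are hypothetical:

def get_num_bits(precision_info: dict[str, int] | None, name: str, default: int = 4) -> int:
    """Return the bit width for a weight, falling back to INT4 when unmapped."""
    if precision_info and name in precision_info:
        return precision_info[name]
    return default

# Hypothetical result of get_precision_info(model, int8_layers="layers.0"):
precision_info = {
    "model.layers.0.attn.qkv_proj.MatMul.weight": 8,  # forced to INT8 (per-channel)
    "model.layers.1.attn.qkv_proj.MatMul.weight": 4,  # stays INT4 (block-wise)
}

weights = [
    "model.layers.0.attn.qkv_proj.MatMul.weight",
    "model.layers.1.attn.qkv_proj.MatMul.weight",
    "model.layers.2.mlp.down_proj.MatMul.weight",  # unmapped -> defaults to 4 bits
]
for w in weights:
    print(w, "->", get_num_bits(precision_info, w), "bits")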

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I nibble bits in fields of four,
Then hop to eight when layers implore.
I map each weight, pick bits with care,
Place Q/DQ nodes here and there.
Scales snug, zero-point true—quant dreams split anew. 🐇✨


Pre-merge checks

✅ Passed checks (3 passed)
  • Title Check: Passed. The title "[5506930]Add support in ModelOpt for generating mixed-precision (INT4…" accurately captures the primary change — adding mixed-precision INT4/INT8 support to ModelOpt — and aligns with the PR objectives and code changes that introduce per-weight precision, CLI flags, and layer mapping; it is concise and focused, though it includes a numeric prefix and an ellipsis that are minor noise.
  • Docstring Coverage: Passed. Docstring coverage is 80.00%, which meets the required threshold of 80.00%.
  • Description Check: Passed. Check skipped because CodeRabbit's high-level summary is enabled.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
modelopt/onnx/quantization/int4.py (6)

579-587: Fix: pass a GraphSurgeon graph to _quantize_gather_nodes in AWQ-clip.

_quantize_gather_nodes iterates graph.nodes (GraphSurgeon), but here graph is an ONNX GraphProto (augmented_model.graph). This will throw at runtime.

Apply:

-        gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
-            graph,
+        gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
+            graph_gs,
             nodes_to_exclude,
             gather_quantize_axis,
             gather_block_size,
             use_zero_point=False,
             dq_only=True,
         )

1327-1334: Per-channel attribute missing for Gather DQ nodes in AWQ-lite.

Mirror the per-channel handling for Gather weights too.

         qdq.insert_dq_nodes(
             graph_gs,
             gather_s_map,
             quantized_weights=gather_w_map,
             attributes=gather_dq_node_attributes,
             zero_points=gather_zp_map if use_zero_point else None,
-            precision_info=precision_info,
+            precision_info=precision_info,
+            is_per_channel=is_per_channel,
         )

159-166: Avoid mutable default list for nodes_to_exclude in RTN API.

-def quantize_rtn(
+def quantize_rtn(
     onnx_model: onnx.ModelProto,
     block_size: int,
     dq_only: bool = False,
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
     precision_info: dict[str, int] | None = None,
     **kwargs: Any,
 ) -> onnx.ModelProto:
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])

457-468: Avoid mutable default list for nodes_to_exclude in AWQ-clip API.

-def _quantize_awq_clip(
+def _quantize_awq_clip(
     onnx_model: onnx.ModelProto,
@@
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])

1000-1014: Avoid mutable default list for nodes_to_exclude in AWQ-lite API.

-def _quantize_awq_lite(
+def _quantize_awq_lite(
@@
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])

1470-1484: Avoid mutable default list for nodes_to_exclude; add params doc for mixed precision.

-def quantize(
+def quantize(
@@
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
@@
-    logger.debug(f"Excluding nodes matching patterns: {nodes_to_exclude}")
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    logger.debug(f"Excluding nodes matching patterns: {nodes_to_exclude}")

Optional: Update the docstring to mention k_quant_mixed and int8_layers behavior.

🧹 Nitpick comments (8)
modelopt/onnx/quantization/quant_utils.py (3)

169-171: Consider inlining this simple calculation.

This helper function is only used once (in _pad) and performs a trivial calculation. Consider inlining it directly where used to reduce unnecessary abstraction.

-def _next_block_size_multiple(x: float, block_size: int) -> float:
-    return math.ceil(x / block_size) * block_size
-

 def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray:
     """Pads `w` to next largest multiple of block_size, on quantize_axis."""
     assert quantize_axis <= len(w.shape), (
         f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
     )
     
     if w.shape[quantize_axis] % block_size == 0:
         return w
     
-    pad_width = (
-        _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis]
-    )
+    pad_width = math.ceil(w.shape[quantize_axis] / block_size) * block_size - w.shape[quantize_axis]

190-201: Consider using numpy slicing syntax for all axes.

The function handles only 2D arrays but could be made more general and consistent using numpy's ellipsis syntax.

 def _depad(w: np.ndarray, orig_shape: tuple, quantize_axis: int = 0) -> np.ndarray:
     """Depad quantize_axis to original shape."""
     if w.shape == orig_shape:
         return w
-    ans = None
-    if quantize_axis == 0:
-        ans = w[0 : orig_shape[0], ...]
-    elif quantize_axis == 1:
-        ans = w[..., 0 : orig_shape[1]]
-    else:
-        raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array")
-    return ans
+    
+    if quantize_axis not in [0, 1]:
+        raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array")
+    
+    slices = [slice(None)] * len(w.shape)
+    slices[quantize_axis] = slice(0, orig_shape[quantize_axis])
+    return w[tuple(slices)]

323-338: Add docstring details about return values.

The function's docstring should specify what the three return values represent.

 def quant_tensor(
     w: np.ndarray,
     block_size: int,
     quantize_axis: int = 0,
     alpha: float = 1.0,
     use_zero_point: bool = False,
     precision_info: dict[str, int] | None = None,
     name: str | None = None,
 ):
-    """Quantize a tensor using alpha etc. and return the quantized tensor."""
+    """Quantize a tensor using alpha etc. and return the quantized tensor.
+    
+    Returns:
+        tuple: A tuple containing:
+            - wq: The quantized weight tensor (np.ndarray)
+            - scale: The scale factors used for quantization (np.ndarray)
+            - zp: The zero-point values (np.ndarray or None if not using zero-point)
+    """
     scale, zp = find_scales(
         w, block_size, quantize_axis, alpha, use_zero_point, precision_info, name
     )
     wq = rtn(w, scale, block_size, quantize_axis, zp, precision_info, name)
     return wq, scale, zp
examples/windows/onnx_ptq/genai_llm/quantize.py (1)

599-613: Consider validating the int8_layers pattern format.

The int8_layers parameter accepts a comma-separated string of layer patterns, but there's no validation of the format. Consider adding basic validation to catch user errors early.

 parser.add_argument(
     "--int8_layers",
     type=str,
     default="",
     help=(
         "Comma-separated list of layer patterns to quantize to INT8 instead of INT4."
         "Example: 'layers.0,layers.1,lm_head'"
     ),
 )
+
+def validate_int8_layers(layers_str: str) -> bool:
+    """Validate the format of int8_layers string."""
+    if not layers_str:
+        return True
+    # Basic validation: check for valid characters and structure
+    import re
+    pattern = r'^[a-zA-Z0-9_.,\-]+$'
+    return bool(re.match(pattern, layers_str))

Then add validation after parsing:

if args.int8_layers and not validate_int8_layers(args.int8_layers):
    parser.error("Invalid format for --int8_layers. Use comma-separated layer patterns.")
modelopt/onnx/quantization/qdq_utils.py (3)

310-322: Add input validation for the attributes parameter.

The function modifies the attributes dictionary in place, which could cause unexpected side effects if the same dictionary is reused.

 def update_attributes(attrib: dict[str, Any] | None = None, is_per_channel: bool = False):
     """Update attribute dictionary for quantization nodes.
     
     If per-channel quantization is enabled, sets the 'axis' attribute to 1 and removes
     the 'block_size' attribute if present.
+    
+    Args:
+        attrib: Attribute dictionary to update (will be modified in place if not None)
+        is_per_channel: Whether to enable per-channel quantization
+    
+    Returns:
+        The updated attribute dictionary
     """
     if is_per_channel:
         if attrib is not None:
             attrib["axis"] = 1
             if "block_size" in attrib:
                 attrib.pop("block_size")
     return attrib

354-369: Consider extracting tensor dtype selection logic to a helper function.

The tensor dtype selection logic is duplicated in both insert_dq_nodes and insert_qdq_nodes. Consider extracting it to reduce duplication and improve maintainability.

+def get_tensor_dtype(precision_info: dict[str, int] | None, name: str, has_zero_point: bool) -> int:
+    """Get the appropriate tensor dtype based on precision info and zero point presence.
+    
+    Args:
+        precision_info: Dictionary mapping tensor names to bit widths
+        name: Name of the tensor
+        has_zero_point: Whether the tensor has a zero point
+    
+    Returns:
+        ONNX tensor data type constant
+    """
+    if precision_info and name in precision_info:
+        bit_width = precision_info[name]
+        if has_zero_point:
+            dtype_str = onnx_bit_dtype_unsigned_map[bit_width]
+        else:
+            dtype_str = onnx_bit_dtype_signed_map[bit_width]
+        return onnx_dtype_map[dtype_str]
+    else:
+        # Default to 4-bit
+        return onnx.TensorProto.INT4 if not has_zero_point else onnx.TensorProto.UINT4

 def _insert_helper(
     name: str,
     wq: np.ndarray,
     scale: np.ndarray,
     dq_nodes: dict[str, gs.Node],
     zp: np.ndarray,
     attrs: dict[str, Any] | None = None,
     precision_info: dict[str, int] | None = None,
     is_per_channel: bool = False,
 ):
     attrib = dict(attrs) if attrs is not None else None
-    if precision_info and name in precision_info:
-        tensor_dtype = (
-            onnx_dtype_map[onnx_bit_dtype_signed_map[precision_info[name]]]
-            if zp is None
-            else onnx_dtype_map[onnx_bit_dtype_unsigned_map[precision_info[name]]]
-        )
-        # do per-channel quantization for int8 as no support for int8 block-wise dq node
-        if precision_info[name] == 8:
-            # reshape scale to be per-channel
-            scale = scale.reshape(-1)
-            attrib = update_attributes(attrib, True)
-    else:
-        tensor_dtype = onnx.TensorProto.INT4 if zp is None else onnx.TensorProto.UINT4
+    tensor_dtype = get_tensor_dtype(precision_info, name, zp is not None)
+    
+    # do per-channel quantization for int8 as no support for int8 block-wise dq node
+    if precision_info and name in precision_info and precision_info[name] == 8:
+        # reshape scale to be per-channel
+        scale = scale.reshape(-1)
+        attrib = update_attributes(attrib, True)

361-364: Validate per‐channel scale size before flattening. Although scale.reshape(-1) preserves all elements, add an explicit check (e.g. assert scale.size == wq.shape[1]) before calling reshape to guard against mismatches when quantizing along axis 1.
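
A possible placement for that guard, sketched with hypothetical shapes (in the real helper, wq and scale come from the DQ insertion path):

import numpy as np

wq = np.zeros((128, 64), dtype=np.int8)           # quantized weight; axis 1 = output channels
scale = np.random.rand(1, 64).astype(np.float32)  # one scale per output channel

# Guard before flattening: per-channel reshape only makes sense with exactly one scale per channel.
assert scale.size == wq.shape[1], (
    f"per-channel scale mismatch: {scale.size} scales for {wq.shape[1]} channels"
)
scale = scale.reshape(-1)                         # now shape (64,)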

modelopt/onnx/quantization/int4.py (1)

620-624: Use logger.warning instead of deprecated logger.warn.

-        logger.warn("Augmented ONNX model or external data file was not found")
+        logger.warning("Augmented ONNX model or external data file was not found")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d6d2e75 and caa9965.

📒 Files selected for processing (4)
  • examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
  • modelopt/onnx/quantization/int4.py (29 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (8 hunks)
  • modelopt/onnx/quantization/quant_utils.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/quant_utils.py (5)
  • _pad (173-187)
  • dq_tensor (297-320)
  • find_scales (204-252)
  • quant_tensor (323-337)
  • rtn (255-294)
modelopt/onnx/quantization/qdq_utils.py (2)
  • insert_dq_nodes (324-406)
  • insert_qdq_nodes (409-463)
modelopt/onnx/quantization/graph_utils.py (1)
  • expand_node_names_from_patterns (628-639)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • num_bits (180-182)
  • num_bits (185-187)
  • axis (279-281)
  • axis (284-286)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (10)
modelopt/onnx/quantization/quant_utils.py (3)

160-167: LGTM! Clean implementation of bit-width selection logic.

The function provides a clear and simple way to get the number of bits for quantization based on the precision info dictionary, with a sensible default of 4 bits.


173-188: LGTM! Proper padding implementation with clear assertions.

The padding logic correctly handles different axes and includes appropriate assertions for validation.


297-321: LGTM! Clear dequantization implementation.

The dequantization logic correctly handles both signed and unsigned cases with proper padding and depadding.

examples/windows/onnx_ptq/genai_llm/quantize.py (2)

368-369: LGTM! Status output correctly includes the new k_quant_mixed flag.

The logging statement properly shows the new mixed-precision quantization option.


438-439: LGTM! Proper propagation of mixed-precision parameters.

The new parameters are correctly passed to the quantize_int4 function.

modelopt/onnx/quantization/qdq_utils.py (2)

57-61: LGTM! Proper extension of dtype mappings for 4-bit support.

The addition of INT4 and UINT4 to the dtype maps enables proper 4-bit quantization support.


432-440: LGTM! Consistent handling of per-weight precision in QDQ nodes.

The implementation correctly mirrors the DQ node logic for handling per-weight precision information.

modelopt/onnx/quantization/int4.py (3)

206-209: LGTM: precision_info is correctly threaded in RTN path.

Scale computation, quantization, and Q/DQ insertion honor per-weight precision and per-channel when applicable.

Also applies to: 240-243, 248-255, 259-265


315-322: LGTM: precision_info propagation in AWQ-clip compute path.

The clip search and final quantization use per-weight precision consistently.

Also applies to: 362-366, 526-528, 552-555


705-707: LGTM: precision_info propagation in AWQ-lite compute path.

Both per-node and per-subgraph scale searches respect mixed precision; quantization/dequantization calls pass the context correctly.

Also applies to: 762-771, 782-783, 882-884, 939-952, 1112-1131, 1182-1188

Comment on lines 1311 to 1320
dq_node_attributes = {"axis": 0, "block_size": block_size}

qdq.insert_dq_nodes(
    graph_gs,
    scales,
    quantized_weights=gemm_weights_quantized,
    attributes=dq_node_attributes,
    zero_points=zero_points if use_zero_point else None,
    precision_info=precision_info,
)


🛠️ Refactor suggestion

Per-channel attribute not propagated in AWQ-lite.

When block_size == -1, you set per-channel in RTN and AWQ-clip, but not here. TRT plugins require this attribute for correct dequantization.

Apply:

     dq_node_attributes = {"axis": 0, "block_size": block_size}
+    is_per_channel = block_size == -1

     qdq.insert_dq_nodes(
         graph_gs,
         scales,
         quantized_weights=gemm_weights_quantized,
         attributes=dq_node_attributes,
         zero_points=zero_points if use_zero_point else None,
-        precision_info=precision_info,
+        precision_info=precision_info,
+        is_per_channel=is_per_channel,
     )
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/int4.py around lines 1311-1320, the
dq_node_attributes only contains "axis" and "block_size" for AWQ-lite; when
block_size == -1 you must also set the per-channel flag (as done for RTN and
AWQ-clip) so TRT plugins get correct dequantization info. Modify the
construction of dq_node_attributes to include "per_channel": True when
block_size == -1 (otherwise keep it False or omit), then pass that attributes
dict into qdq.insert_dq_nodes so the per-channel attribute is propagated.

Comment on lines 1357 to 1377
def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
    """Check if layer should be quantized to INT8.

    The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
    The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'.
    To match these, we:
    - Remove the leading slash from the node name.
    - Replace all '/' with '.' to match the naming convention of the initializer.
    This allows us to correctly identify which weights should be quantized to INT8.
    """
    if not int8_layers:
        return False
    normalized_patterns = []
    for pattern in int8_layers:
        p = pattern.lstrip("/")
        p = p.replace("/", ".")
        normalized_patterns.append(p)
    return any(norm_pattern in layer_name for norm_pattern in normalized_patterns)



🛠️ Refactor suggestion

Avoid substring false-positives when matching INT8 layers.

Substring matching with `in` can misclassify: e.g., layers.1 matches layers.11. Use token-exact matching.

-def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
+def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool:
@@
-    normalized_patterns = []
-    for pattern in int8_layers:
-        p = pattern.lstrip("/")
-        p = p.replace("/", ".")
-        normalized_patterns.append(p)
-    return any(norm_pattern in layer_name for norm_pattern in normalized_patterns)
+    # Normalize both to dot-delimited tokens and require exact token sequence match.
+    def tokens(s: str) -> list[str]:
+        return s.lstrip("/").replace("/", ".").split(".")
+    hay = tokens(layer_name)
+    for pat in int8_layers:
+        needle = tokens(pat)
+        n, m = len(hay), len(needle)
+        for i in range(0, n - m + 1):
+            if hay[i : i + m] == needle:
+                return True
+    return False

Comment on lines 1379 to 1387
def get_layer_precision_mapping(
    onnx_model: onnx.ModelProto,
    int8_precision_pattern: str | None = None,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
    graph = onnx_model.graph

    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
    # Collect quantizable weight tensors


🛠️ Refactor suggestion

Don’t use a mutable list as a default argument.

nodes_to_exclude: list[str] | None = [r"/lm_head"] risks cross-call mutation.

-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
     onnx_model: onnx.ModelProto,
     int8_precision_pattern: str | None = None,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
 ):
-    graph = onnx_model.graph
+    graph = onnx_model.graph
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]

Optional: allow regex patterns in int8_precision_pattern to be expanded via expand_node_names_from_patterns for robustness.

🤖 Prompt for AI Agents
In modelopt/onnx/quantization/int4.py around lines 1379 to 1387, the function
signature uses a mutable default list for nodes_to_exclude which can lead to
cross-call mutation; change the signature to use None as the default
(nodes_to_exclude: list[str] | None = None) and inside the function set
nodes_to_exclude = [r"/lm_head"] if nodes_to_exclude is None before calling
expand_node_names_from_patterns; additionally (optional) if
int8_precision_pattern should accept regexes, pass it through
expand_node_names_from_patterns or a similar helper so patterns are expanded
consistently before use.

Comment on lines 255 to 326
def rtn(
    w: np.ndarray,
    s: np.ndarray,
    block_size: int,
    quantize_axis: int = 0,
    zp: np.ndarray = None,
    precision_info: dict[str, int] | None = None,
    name: str | None = None,
) -> np.ndarray:
    """Quantizes `w` with scale factors `s` via Round-to-Nearest.
    Ties are broken by rounding to the nearest even number.
    """
    num_bits = get_num_bits(precision_info, name)
    # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node,
    # set block_size to the size of the quantize_axis dimension to do per-channel quantization
    if block_size == -1 or num_bits == 8:
        block_size = w.shape[quantize_axis]
    w_padded = _pad(w, block_size, quantize_axis)
    num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis]
    if zp is None:
        maxq = 2 ** (num_bits - 1) - 1
        minq = -(2 ** (num_bits - 1))
        w_padded = (
            np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
            .clip(minq, maxq)
            .astype(np.int8)
        )
    else:
        maxq = (2**num_bits) - 1
        minq = 0
        w_padded = (
            (
                np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
                + zp.repeat(num_blocks, axis=quantize_axis)
            )
            .clip(minq, maxq)
            .astype(np.int8)
        )
    return _depad(w_padded, w.shape, quantize_axis)



🛠️ Refactor suggestion

Improve numerical stability for edge cases.

When using zero-point quantization, division by s could cause issues if s contains values close to zero (even though clipped by CLIP_MIN). Consider adding a safety check.

 def rtn(
     w: np.ndarray,
     s: np.ndarray,
     block_size: int,
     quantize_axis: int = 0,
     zp: np.ndarray = None,
     precision_info: dict[str, int] | None = None,
     name: str | None = None,
 ) -> np.ndarray:
     """Quantizes `w` with scale factors `s` via Round-to-Nearest.
     
     Ties are broken by rounding to the nearest even number.
     """
     num_bits = get_num_bits(precision_info, name)
     # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node,
     # set block_size to the size of the quantize_axis dimension to do per-channel quantization
     if block_size == -1 or num_bits == 8:
         block_size = w.shape[quantize_axis]
     w_padded = _pad(w, block_size, quantize_axis)
     num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis]
+    
+    # Ensure scale values are not too small to avoid numerical instability
+    if np.any(np.abs(s) < CLIP_MIN):
+        raise ValueError(f"Scale values too small for stable quantization: min={np.min(np.abs(s))}")
+    
     if zp is None:
         maxq = 2 ** (num_bits - 1) - 1
         minq = -(2 ** (num_bits - 1))
         w_padded = (
             np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
             .clip(minq, maxq)
             .astype(np.int8)
         )
     else:
         maxq = (2**num_bits) - 1
         minq = 0
         w_padded = (
             (
                 np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
                 + zp.repeat(num_blocks, axis=quantize_axis)
             )
             .clip(minq, maxq)
             .astype(np.int8)
         )
     return _depad(w_padded, w.shape, quantize_axis)

)

parser.add_argument(
"--k_quant_mixed",
Contributor


What does k stand for in k_quant_mixed?

Author

@ynankani ynankani Sep 10, 2025


k stands for groupwise/block-wise quantization. We are doing Int4 block-wise quantization plus a few layers in Int8 per-channel quantization (currently we don't have Int8 block-wise DequantizeLinear node support on the trt-rtx side).
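
As a rough illustration of that split (a standalone numpy sketch, not the ModelOpt code path): block-wise INT4 keeps one scale per block along the quantize axis, while the INT8 fallback collapses the block to the full axis, leaving a single scale per output channel, which is what a per-channel DequantizeLinear node expects. The shapes and the block size of 128 are assumptions for the example:

import numpy as np

w = np.random.randn(256, 64).astype(np.float32)  # (in_features, out_features)
block_size = 128

# INT4 block-wise: one scale per (block along axis 0, output channel).
w_blocks = w.reshape(-1, block_size, w.shape[1])  # (2, 128, 64)
scales_int4 = np.abs(w_blocks).max(axis=1) / 7.0  # (2, 64); int4 symmetric max = 7
print(scales_int4.shape)

# INT8 per-channel: the block spans the whole axis, one scale per output channel.
scales_int8 = np.abs(w).max(axis=0) / 127.0       # (64,); int8 symmetric max = 127
print(scales_int8.shape)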

@ynankani
Author

Need to make a few more changes per discussion with Vishal. Will request review once the updated changes are ready.

…+INT8) ONNX models, refactored changes and handle comments

Signed-off-by: unknown <[email protected]>
@ynankani ynankani force-pushed the ynankani/mixed_precision-int4-int8 branch from caa9965 to b6a39be Compare September 12, 2025 18:22

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)

417-450: Zero-point dtype should be unsigned.

When creating Q/DQ, zp should use the unsigned dtype. Pass has_zero_point=True to get_tensor_dtype.

-        tensor_dtype = get_tensor_dtype(num_bits)
+        tensor_dtype = get_tensor_dtype(num_bits)
@@
-        zp_tensor = make_gs_zp(name, scale.shape, tensor_dtype)
+        zp_tensor = make_gs_zp(name, scale.shape, get_tensor_dtype(num_bits, has_zero_point=True))
♻️ Duplicate comments (1)
modelopt/onnx/quantization/quant_utils.py (1)

292-326: Numerical safety in rtn.

Add a guard to avoid tiny s after broadcast.

-    if zp is None:
+    # Safety: avoid division by tiny scales
+    if np.any(np.abs(s) < CLIP_MIN):
+        s = np.where(np.abs(s) < CLIP_MIN, CLIP_MIN, s)
+    if zp is None:
         maxq = 2 ** (num_bits - 1) - 1
🧹 Nitpick comments (7)
modelopt/onnx/quantization/graph_utils.py (2)

629-667: Type hint misuse and minor robustness.

  • cast("int", weight_tensor.data_type) uses a string literal as a type param; use cast(int, ...) or drop cast.
  • Consider guarding for Gemm transA (currently TODO) or explicitly documenting unsupported cases.
-        gemm_io_type = cast("int", weight_tensor.data_type)
+        gemm_io_type = cast(int, weight_tensor.data_type)

669-696: Token-exact matcher looks good. Add fast exit and docstring type.

Add -> bool and early return for long patterns.

-def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
+def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool:
@@
-    for pat in int8_layers:
+    for pat in int8_layers:
         needle = tokens(pat)
         n, m = len(hay), len(needle)
+        if m == 0 or m > n:
+            continue
modelopt/onnx/quantization/quant_utils.py (2)

160-177: update_block_size: clarify contract for w=None.

Function dereferences w unconditionally when block_size==-1 or num_bits==8. Either assert or handle None.

-    if block_size is not None and (block_size == -1 or num_bits == 8):
-        return w.shape[quantize_axis]
+    if block_size is not None and (block_size == -1 or num_bits == 8):
+        assert w is not None, "update_block_size requires `w` when per-channel/INT8 is requested"
+        return w.shape[quantize_axis]

196-198: Type hint for _next_block_size_multiple.

Return type is int in practice; adjust for readability/tools.

-def _next_block_size_multiple(x: float, block_size: int) -> float:
+def _next_block_size_multiple(x: float, block_size: int) -> int:
modelopt/onnx/quantization/qdq_utils.py (1)

50-63: ONNX 4-bit enums: confirm availability.

onnx.TensorProto.INT4/UINT4 are relatively new; some environments may lack them. Add a fallback or guard.

Option: probe availability once and raise a clear error.

-onnx_dtype_map = {
+onnx_dtype_map = {
@@
-    "INT4": onnx.TensorProto.INT4,
-    "UINT4": onnx.TensorProto.UINT4,
+    "INT4": getattr(onnx.TensorProto, "INT4", None),
+    "UINT4": getattr(onnx.TensorProto, "UINT4", None),
 }
+
+# Validate availability early
+if onnx_dtype_map["INT4"] is None or onnx_dtype_map["UINT4"] is None:
+    logger.error("This ONNX build does not support 4-bit dtypes (INT4/UINT4).")
+    # Optionally: raise, or fall back to INT8 path depending on product decision

Would you like me to wire a clean fallback to INT8 when 4-bit enums are missing?

modelopt/onnx/quantization/int4.py (2)

221-227: Pattern-vs-name expansion mismatch risk.

nodes_to_exclude is already expanded to concrete names before calling get_precision_info, which expects patterns (it internally expands again). This is benign but redundant and can mis-handle special regex chars in names.

Two options:

  • Pass the original patterns into get_precision_info, or
  • Change get_precision_info to accept already-expanded names and skip expansion.

1369-1376: Doc/type mismatch for int8_layers.

Docstring says “comma-separated list” but default shows [] and the code expects a string or None. Clarify.

-                - **int8_layers** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
-                                              Default: [].
+                - **int8_layers** (str | None): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
+                                              Default: None.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between caa9965 and b6a39be.

📒 Files selected for processing (6)
  • examples/windows/onnx_ptq/genai_llm/README.md (1 hunks)
  • examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
  • modelopt/onnx/quantization/graph_utils.py (2 hunks)
  • modelopt/onnx/quantization/int4.py (27 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (8 hunks)
  • modelopt/onnx/quantization/quant_utils.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
  • examples/windows/onnx_ptq/genai_llm/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/windows/onnx_ptq/genai_llm/quantize.py
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/onnx/quantization/quant_utils.py (1)
  • get_num_bits (179-193)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/graph_utils.py (2)
  • _find_quantizable_weights (629-666)
  • get_precision_info (814-843)
modelopt/onnx/quantization/quant_utils.py (8)
  • _pad (200-214)
  • dq_tensor (328-344)
  • find_scales (242-289)
  • get_num_bits (179-193)
  • quant_tensor (347-366)
  • rtn (292-325)
  • update_block_size (160-176)
  • update_scale_map_for_per_channel_nodes (231-239)
modelopt/onnx/quantization/qdq_utils.py (2)
  • insert_dq_nodes (327-399)
  • insert_qdq_nodes (402-456)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • num_bits (180-182)
  • num_bits (185-187)
  • axis (279-281)
  • axis (284-286)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/onnx/quantization/int4.py (2)

562-571: Per-channel handling is correctly propagated.

Good use of update_scale_map_for_per_channel_nodes and attribute adjustment in qdq insertion to set axis=1 and drop block_size for INT8/per-channel. This aligns with TRT plugin expectations.

Also applies to: 1264-1273


244-253: Gather path: mixed-precision propagation looks correct.

You compute per-weight num_bits and reuse update_block_size for gather. Nice reuse and consistency across RTN/AWQ paths.

Also applies to: 1253-1261

Comment on lines +698 to +707
def validate_int8_layers(layers_str: str) -> bool:
    """Validate the format of int8_layers string."""
    if not layers_str:
        return True
    # Basic validation: check for valid characters and structure
    import re

    pattern = r"^[a-zA-Z0-9_.,\-]$"
    return bool(re.match(pattern, layers_str))


@coderabbitai coderabbitai bot Sep 12, 2025


⚠️ Potential issue

Fix regex validation to accept real layer patterns.

The current pattern r"^[a-zA-Z0-9_.,\-]$" only matches a single character and disallows /. This will reject valid inputs like /model/layers.13/attn/qkv_proj/MatMul.

Apply:

-def validate_int8_layers(layers_str: str) -> bool:
+def validate_int8_layers(layers_str: str) -> bool:
@@
-    pattern = r"^[a-zA-Z0-9_.,\-]$"
-    return bool(re.match(pattern, layers_str))
+    # Allow comma-separated list of path-like tokens
+    pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
+    return bool(re.match(pattern, layers_str))
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/graph_utils.py around lines 698 to 707, the regex
only matches a single character and disallows "/" so valid layer paths are
rejected; update the validation to allow one-or-more characters and include "/"
(and keep letters, digits, underscore, dot, comma and hyphen), e.g. replace the
pattern with a character-class that includes "/" and a quantifier (and ensure
hyphen is either escaped or last), and use re.fullmatch instead of re.match so
the entire string is validated (e.g. pattern like "^[A-Za-z0-9_/.,-]+$" with
re.fullmatch).

Contributor

@i-riyad i-riyad Sep 19, 2025


Look into this suggestion



🧩 Analysis chain

@i-riyad I've analyzed the suggestion and it's correct - there's a critical bug in the regex validation.

The CodeRabbit suggestion is absolutely correct. The current regex has two major flaws:

  1. Single character matching: The pattern ^[a-zA-Z0-9_.,\-]$ only matches exactly one character due to the absence of quantifiers
  2. Missing forward slash: Real ONNX node names like /model/layers.13/attn/qkv_proj/MatMul contain forward slashes which aren't allowed

The suggested fix properly addresses both issues by:

  • Adding + quantifier to match one or more characters
  • Including / in the character class
  • Supporting comma-separated values with optional whitespace

However, I'd recommend a small improvement to make the validation more robust:

def validate_int8_layers(layers_str: str) -> bool:
    """Validate the format of int8_layers string."""
    if not layers_str:
        return True
    # Allow comma-separated list of path-like tokens
    # Pattern allows: letters, digits, underscore, dot, slash, hyphen
    # Supports comma-separated values with optional whitespace
    pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
    return bool(re.fullmatch(pattern, layers_str))

Using re.fullmatch() instead of re.match() provides clearer intent (matching the entire string) and is slightly more explicit, though functionally equivalent when using ^ and $ anchors.

The suggestion should definitely be implemented to fix this validation bug.


🏁 Script executed:

#!/bin/bash
# Let's examine the current validate_int8_layers function and test the regex behavior
cd /tmp
cat > test_regex.py << 'EOF'
import re

# Current implementation
def validate_int8_layers_current(layers_str: str) -> bool:
    if not layers_str:
        return True
    pattern = r"^[a-zA-Z0-9_.,\-]$"
    return bool(re.match(pattern, layers_str))

# Suggested implementation
def validate_int8_layers_suggested(layers_str: str) -> bool:
    if not layers_str:
        return True
    pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
    return bool(re.match(pattern, layers_str))

# Test cases
test_cases = [
    "/model/layers.13/attn/qkv_proj/MatMul",
    "/model/layers.0/attn/v_proj/MatMul,/model/layers.1/mlp/down_proj/MatMul",
    "model.layers.13.attn.qkv_proj.MatMul.weight",
    "a",  # single character - should pass for current
    "",   # empty string - should pass for both
    "/invalid/path with spaces/MatMul",  # spaces in path - should fail
    "valid,path",  # simple comma-separated - should pass for suggested
]

print("Testing current vs suggested implementation:")
print("=" * 60)
for test in test_cases:
    current_result = validate_int8_layers_current(test)
    suggested_result = validate_int8_layers_suggested(test)
    print(f"Input: '{test}'")
    print(f"  Current:   {current_result}")
    print(f"  Suggested: {suggested_result}")
    print()
EOF

python test_regex.py

Length of output: 634


Perfect! The test results confirm the analysis. The current implementation is severely broken - it only accepts single characters or empty strings, making it useless for validating real layer patterns.

Key findings:

  • ✅ Current regex rejects ALL valid layer patterns like /model/layers.13/attn/qkv_proj/MatMul
  • ✅ Suggested fix correctly accepts valid patterns while rejecting invalid ones (like paths with spaces)
  • ✅ Both handle empty strings correctly
  • ✅ Suggested pattern properly supports comma-separated values

Recommendation: Implement the CodeRabbit suggestion immediately. This is a critical bug that makes the validation function completely non-functional for its intended purpose.

Here's the corrected implementation:

def validate_int8_layers(layers_str: str) -> bool:
    """Validate the format of int8_layers string."""
    if not layers_str:
        return True
    # Basic validation: check for valid characters and structure
    import re

    # Allow comma-separated list of path-like tokens
    pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
    return bool(re.match(pattern, layers_str))

The current regex would cause the mixed-precision quantization feature to fail silently or reject all user-provided layer specifications, which would be a significant usability issue.

Comment on lines +709 to +714
def get_layer_precision_mapping(
    onnx_model: onnx.ModelProto,
    int8_precision_pattern: str | None = None,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
    """Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model.


🛠️ Refactor suggestion

Avoid mutable default args and normalize exclusions.

Using a list default (nodes_to_exclude=[r"/lm_head"]) risks cross-call mutation. Also normalize nodes_to_exclude up-front.

Apply:

-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
     onnx_model: onnx.ModelProto,
     int8_precision_pattern: str | None = None,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
 ):
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)

Also applies to: 728-731

🤖 Prompt for AI Agents
In modelopt/onnx/quantization/graph_utils.py around lines 709-714 (and similarly
at 728-731), avoid using a mutable list as a default for nodes_to_exclude:
change the signature to use nodes_to_exclude: list[str] | None = None, then
inside the function set nodes_to_exclude = [r"/lm_head"] if nodes_to_exclude is
None else list(nodes_to_exclude) to prevent shared-state mutations; immediately
normalize entries (e.g., strip/ensure raw strings or compiled regex as the code
expects) so downstream logic can assume a consistent list type and format.

Comment on lines +814 to +844
def get_precision_info(
    onnx_model: onnx.ModelProto,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
    **kwargs: Any,
):
    """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).

    This function determines the quantization precision for each weight tensor in the ONNX model,
    based on the provided configuration. If mixed quantization is enabled, it uses the layer
    precision mapping; otherwise, it returns None.

    Args:
        onnx_model (onnx.ModelProto): The ONNX model to analyze.
        nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
        **kwargs: Additional keyword arguments, such as:
            - enable_mixed_quant (bool): Whether to enable mixed quantization.
            - int8_layers (str): Comma-separated list of layer patterns to quantize to INT8.

    Returns:
        dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
            or None if mixed quantization is not enabled.
    """
    precision_info = None
    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
    int8_layers = kwargs.get("int8_layers")
    if enable_mixed_quant:
        precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude)
    else:
        precision_info = None
    return precision_info



🛠️ Refactor suggestion

Same mutable default issue in public API.

Apply same fix as above.

-def get_precision_info(
+def get_precision_info(
     onnx_model: onnx.ModelProto,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
     **kwargs: Any,
 ):
@@
-    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In modelopt/onnx/quantization/graph_utils.py around lines 814 to 844, the
function get_precision_info declares a mutable default parameter
nodes_to_exclude = [r"/lm_head"], which should be avoided; change the signature
to use nodes_to_exclude: list[str] | None = None (or Optional[list[str]] = None)
and inside the function set nodes_to_exclude = [r"/lm_head"] if nodes_to_exclude
is None, preserving the original default behavior while avoiding the mutable
default; update any type hints and callers if necessary.
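
To make the expected call shape concrete, here is a hedged usage sketch of get_precision_info as quoted above. The model path, the int8_layers value, and the returned tensor names are placeholders; the actual call site in int4.py may differ.

import onnx
from modelopt.onnx.quantization.graph_utils import get_precision_info

onnx_model = onnx.load("model.onnx")  # placeholder path

# Mixed quant on, with an explicit INT8 layer pattern:
precision_info = get_precision_info(
    onnx_model,
    nodes_to_exclude=[r"/lm_head"],
    enable_mixed_quant=True,
    int8_layers="layers.0,layers.1",
)
# -> e.g. {"...qkv_proj.MatMul weight": 8, "...other weight": 4}  (names illustrative)

# Mixed quant off: the function returns None and the caller falls back to uniform INT4.
assert get_precision_info(onnx_model, enable_mixed_quant=False) is None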

Comment on lines +311 to +325
def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
    """Get the appropriate tensor dtype based on precision info and zero point presence.
    Args:
        num_bits: Number of bits for quantization
        has_zero_point: Whether the tensor has a zero point
    Returns:
        ONNX tensor data type constant
    """
    if has_zero_point:
        dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
    else:
        dtype_str = onnx_bit_dtype_signed_map[num_bits]
    return onnx_dtype_map[dtype_str]


🛠️ Refactor suggestion

Validate num_bits and map lookups.

Guard against unexpected num_bits and derive dtype accordingly.

 def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
@@
-    if has_zero_point:
-        dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
-    else:
-        dtype_str = onnx_bit_dtype_signed_map[num_bits]
+    valid_bits = onnx_bit_dtype_unsigned_map.keys() & onnx_bit_dtype_signed_map.keys()
+    if num_bits not in valid_bits:
+        raise ValueError(f"Unsupported num_bits={num_bits}; expected one of {sorted(valid_bits)}")
+    dtype_str = (
+        onnx_bit_dtype_unsigned_map[num_bits]
+        if has_zero_point
+        else onnx_bit_dtype_signed_map[num_bits]
+    )
     return onnx_dtype_map[dtype_str]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
-    """Get the appropriate tensor dtype based on precision info and zero point presence.
-    Args:
-        num_bits: Number of bits for quantization
-        has_zero_point: Whether the tensor has a zero point
-    Returns:
-        ONNX tensor data type constant
-    """
-    if has_zero_point:
-        dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
-    else:
-        dtype_str = onnx_bit_dtype_signed_map[num_bits]
-    return onnx_dtype_map[dtype_str]
+def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
+    """Get the appropriate tensor dtype based on precision info and zero point presence.
+    Args:
+        num_bits: Number of bits for quantization
+        has_zero_point: Whether the tensor has a zero point
+    Returns:
+        ONNX tensor data type constant
+    """
+    valid_bits = onnx_bit_dtype_unsigned_map.keys() & onnx_bit_dtype_signed_map.keys()
+    if num_bits not in valid_bits:
+        raise ValueError(f"Unsupported num_bits={num_bits}; expected one of {sorted(valid_bits)}")
+    dtype_str = (
+        onnx_bit_dtype_unsigned_map[num_bits]
+        if has_zero_point
+        else onnx_bit_dtype_signed_map[num_bits]
+    )
+    return onnx_dtype_map[dtype_str]
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/qdq_utils.py around lines 311 to 325, the function
does unguarded lookups on num_bits and mapping dicts which can raise KeyError
for unexpected values; validate that num_bits is one of the supported keys
(e.g., check presence in onnx_bit_dtype_unsigned_map and
onnx_bit_dtype_signed_map) before using it, raise a clear ValueError if
unsupported, and guard the final onnx_dtype_map lookup similarly (or provide a
default/raise with context) so the function always fails fast with a clear
message instead of crashing with a KeyError.
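
As a rough illustration of what the suggested guard buys, here is a self-contained sketch with assumed map contents; the real maps live in qdq_utils.py and may differ (INT4/UINT4 enums require onnx>=1.16).

from onnx import TensorProto

# Assumed layouts of the bit-width -> dtype maps added in this PR.
onnx_bit_dtype_signed_map = {4: "INT4", 8: "INT8"}
onnx_bit_dtype_unsigned_map = {4: "UINT4", 8: "UINT8"}
onnx_dtype_map = {
    "INT4": TensorProto.INT4,
    "UINT4": TensorProto.UINT4,
    "INT8": TensorProto.INT8,
    "UINT8": TensorProto.UINT8,
}

# With the guard in place: get_tensor_dtype(4, has_zero_point=True) -> TensorProto.UINT4,
# get_tensor_dtype(8) -> TensorProto.INT8, and get_tensor_dtype(16) raises a clear
# ValueError instead of a bare KeyError deep inside Q/DQ insertion.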

Comment on lines +200 to +215
def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray:
    """Pads `w` to next largest multiple of block_size, on quantize_axis."""
    assert quantize_axis <= len(w.shape), (
        f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
    )

    if w.shape[quantize_axis] % block_size == 0:
        return w

    pad_width = (
        _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis]
    )
    pads = [(0, 0) for _ in range(len(w.shape))]
    pads[quantize_axis] = (0, pad_width)
    return np.pad(w, pads, mode="constant", constant_values=0)


⚠️ Potential issue

Out-of-bounds axis check in _pad (can crash).

quantize_axis <= len(w.shape) allows quantize_axis == len(w.shape), which is out of bounds; the subsequent w.shape[quantize_axis] indexing then raises IndexError.

Apply:

-    assert quantize_axis <= len(w.shape), (
+    assert 0 <= quantize_axis < len(w.shape), (
         f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray:
-    """Pads `w` to next largest multiple of block_size, on quantize_axis."""
-    assert quantize_axis <= len(w.shape), (
-        f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
-    )
-    if w.shape[quantize_axis] % block_size == 0:
-        return w
-    pad_width = (
-        _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis]
-    )
-    pads = [(0, 0) for _ in range(len(w.shape))]
-    pads[quantize_axis] = (0, pad_width)
-    return np.pad(w, pads, mode="constant", constant_values=0)
+def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray:
+    """Pads `w` to next largest multiple of block_size, on quantize_axis."""
+    assert 0 <= quantize_axis < len(w.shape), (
+        f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
+    )
+    if w.shape[quantize_axis] % block_size == 0:
+        return w
+    pad_width = (
+        _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis]
+    )
+    pads = [(0, 0) for _ in range(len(w.shape))]
+    pads[quantize_axis] = (0, pad_width)
+    return np.pad(w, pads, mode="constant", constant_values=0)
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/quant_utils.py around lines 200 to 215, the
assertion `quantize_axis <= len(w.shape)` incorrectly allows equality and can
cause an out-of-bounds index; replace it with a proper bounds check such as
`assert 0 <= quantize_axis < len(w.shape)` (or if negative axes should be
allowed, use `assert -len(w.shape) <= quantize_axis < len(w.shape)`) and keep
the same error message/context so indexing into w.shape cannot fail.
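
To see the failure mode concretely, here is a small self-contained repro sketch. _next_block_size_multiple is re-implemented here as a plain ceiling-to-multiple helper so the snippet runs on its own, and the assert uses the corrected bound from the suggestion above.

import numpy as np


def _next_block_size_multiple(n: int, block_size: int) -> int:
    # Smallest multiple of block_size that is >= n.
    return ((n + block_size - 1) // block_size) * block_size


def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray:
    assert 0 <= quantize_axis < len(w.shape), (
        f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
    )
    if w.shape[quantize_axis] % block_size == 0:
        return w
    pad_width = _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis]
    pads = [(0, 0) for _ in range(len(w.shape))]
    pads[quantize_axis] = (0, pad_width)
    return np.pad(w, pads, mode="constant", constant_values=0)


print(_pad(np.ones((5, 3), dtype=np.float32), block_size=4).shape)  # (8, 3): rows zero-padded to a multiple of 4

# With the original `quantize_axis <= len(w.shape)` check, quantize_axis=2 on a 2-D array
# passed the assert (2 <= 2) and crashed later on w.shape[2] with an IndexError;
# the corrected bound rejects it up front with a readable message.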

Comment on lines +242 to +264
def find_scales(
    w: np.ndarray,
    block_size: int,
    quantize_axis: int = 0,
    alpha: float = 1.0,
    use_zero_point: bool = False,
    num_bits: int = 4,
):
    """Find scale factors for `w` via `s = max(w.block(block_size)) / 7`."""
    w = _pad(w, block_size, quantize_axis)
    if quantize_axis == 0:
        w = w.T

    s_last_dim = w.shape[-1] // block_size
    s_shape = list(w.shape)
    s_shape[-1] = s_last_dim
    z = None
    if not use_zero_point:
        scale = 2 ** (num_bits - 1) - 1
        w_amax = np.abs(w.reshape(-1, block_size)).max(axis=-1)
        s = (w_amax * alpha) / scale
        s = s.reshape(s_shape)
    else:

⚠️ Potential issue

Prevent division by zero: clamp scales in both branches.

When an all-zero block occurs, s is 0 in the no-zero-point path; the later division in rtn then produces inf/NaN.

Apply:

-        s = (w_amax * alpha) / scale
+        s = (w_amax * alpha) / scale
+        s = np.clip(s, CLIP_MIN, None)
@@
-        s = s.reshape(s_shape)
+        s = np.clip(s, CLIP_MIN, None).reshape(s_shape)

Also applies to: 269-289

🤖 Prompt for AI Agents
In modelopt/onnx/quantization/quant_utils.py around lines 242-264 (and also
apply same fix to lines 269-289), the computed scale array s can be zero for
all-zero blocks which will cause division-by-zero later; clamp s to a small
positive epsilon after computing it in both the no-zero-point and zero-point
branches (e.g. s = np.maximum(s, 1e-8) or set zero entries to eps) so downstream
divisions never produce inf/NaN, preserving the original shapes and dtype.
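
A tiny numeric sketch of why the clamp matters; the CLIP_MIN value below is an assumed illustrative constant (the PR defines its own CLIP_MIN in quant_utils.py).

import numpy as np

CLIP_MIN = 1e-5  # assumed illustrative value

block = np.zeros(4, dtype=np.float32)             # an all-zero weight block
scale = np.abs(block).max() / (2 ** (4 - 1) - 1)  # INT4, no zero point: amax / 7 == 0.0

np.seterr(divide="ignore", invalid="ignore")
print(block / scale)                              # [nan nan nan nan] -- what rtn would compute
print(block / np.clip(scale, CLIP_MIN, None))     # [0. 0. 0. 0.] -- finite after the clamp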

codecov bot commented Sep 12, 2025

Codecov Report

❌ Patch coverage is 55.82329% with 110 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.63%. Comparing base (76e8ce2) to head (b6a39be).

Files with missing lines                     Patch %   Lines
modelopt/onnx/quantization/graph_utils.py    26.80%    71 Missing ⚠️
modelopt/onnx/quantization/quant_utils.py    70.65%    27 Missing ⚠️
modelopt/onnx/quantization/int4.py           82.50%     7 Missing ⚠️
modelopt/onnx/quantization/qdq_utils.py      75.00%     5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #310      +/-   ##
==========================================
- Coverage   73.88%   73.63%   -0.25%     
==========================================
  Files         172      172              
  Lines       17444    17587     +143     
==========================================
+ Hits        12889    12951      +62     
- Misses       4555     4636      +81     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


@ynankani ynankani requested a review from i-riyad September 16, 2025 03:59
Comment on lines +600 to +606
"--enable_mixed_quant",
default=False,
action="store_true",
help="True when we want to use mixed quantization",
)
parser.add_argument(
"--int8_layers",
Contributor

Can we just assume mixed_quant is enabled if --int8_layers is non-empty, and remove the --enable_mixed_quant option?

Author

@ynankani ynankani Sep 18, 2025

If --int8_layers is specified, we quantize to int8 only the layers matching the patterns in int8_layers; if only --enable_mixed_quant is specified, we select a hardcoded set of a few important layers, similar to what other quantization tools such as model builder/llama.cpp do.
Example:
"python quantize.py ... --int8_layers="layer.0" --enable_mixed_quant" => all layer.0 matches are quantized to int8
"python quantize.py ... --enable_mixed_quant" => quantize to int8 the first 1/8, the last 1/8, and every 3rd layer in between, for the attn and ffn layers below:
/model/layers.{i}/attn/qkv_proj/MatMul
/model/layers.{i}/attn/v_proj/MatMul
/model/layers.{i}/mlp/down_proj/MatMul
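
A hedged sketch of that default selection heuristic, not the PR's exact code; the "every 3rd layer" rule is interpreted here as layer index % 3 == 0.

def default_int8_layers(num_layers: int) -> list[str]:
    """Pick the first 1/8, the last 1/8, and every 3rd layer in between for the listed MatMuls."""
    eighth = num_layers // 8
    picked = [
        i
        for i in range(num_layers)
        if i < eighth or i >= num_layers - eighth or i % 3 == 0
    ]
    patterns = (
        "/model/layers.{i}/attn/qkv_proj/MatMul",
        "/model/layers.{i}/attn/v_proj/MatMul",
        "/model/layers.{i}/mlp/down_proj/MatMul",
    )
    return [p.format(i=i) for i in picked for p in patterns]


# For, say, a 28-layer model: layers 0-2 and 25-27 plus every 3rd layer in between get INT8
# for these three MatMuls; all other weights stay INT4.
print(len(default_int8_layers(28)))  # 42 node-name patterns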

Contributor

Gotcha, makes sense now.

@ynankani ynankani requested a review from i-riyad September 18, 2025 03:49
default="",
help=(
"Comma-separated list of layer patterns to quantize to INT8 instead of INT4."
"Example: 'layers.0,layers.1,lm_head'"
Contributor

nit: Overrides default mixed quant strategy. Example: 'layers.0,lm_head'

"--enable_mixed_quant",
default=False,
action="store_true",
help="True when we want to use mixed quantization",
Contributor

@i-riyad i-riyad Sep 19, 2025

nit: help="Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd layer quantized to INT8; others to INT4.",
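
Putting the two nits together, a hedged sketch of what the final argparse block might look like; the wording is taken from the suggestions above and is not necessarily the merged text.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_mixed_quant",
    default=False,
    action="store_true",
    help="Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd "
    "layer quantized to INT8; others to INT4.",
)
parser.add_argument(
    "--int8_layers",
    default="",
    help="Comma-separated list of layer patterns to quantize to INT8 instead of INT4. "
    "Overrides the default mixed quant strategy. Example: 'layers.0,lm_head'",
)

args = parser.parse_args(["--enable_mixed_quant", "--int8_layers", "layers.0,lm_head"])
print(args.enable_mixed_quant, args.int8_layers)  # True layers.0,lm_head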
