
Conversation

@ynankani (Contributor) commented Sep 9, 2025

…+INT8) ONNX models

What does this PR do?

Type of change: new feature

Overview: Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models

Usage

python quantize.py --model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --onnx_path=C:\Users\Deepseek_qwen_7b\output\DeepSeek-R1-Distill-Qwen-7B_fp16\model\model.onnx --output_path=C:\Users\Deepseek_qwen_7b\output\Deepseel_r1_distill_qwen_7b_mixed_awq\model.onnx --calib_size=32 --algo=awq_lite --dataset=cnn --calibration_eps=NvTensorRtRtx --no_position_ids --enable_mixed_quant

Testing

Performed sanity/functional testing, MMLU, and perf measurements for the models below:

  1. DeepSeek-R1-Distill-Qwen-7B
  2. Qwen2.5-1.5B-Instruct

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: N/A
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added CLI flags --enable_mixed_quant and --layers_8bit; runtime now reports their values and enforces mixed-precision when layers_8bit is set.
    • Per-layer 8-bit overrides supported for targeted layers.
  • Improvements

    • Enhanced per-channel/per-block quantization for better accuracy and dtype handling (including 4-bit/8-bit outputs).
    • Heuristic and pattern-based per-layer precision mapping for mixed precision.
  • Documentation

    • README updated with the new CLI flags.

@ynankani ynankani requested review from a team as code owners September 9, 2025 19:07

coderabbitai bot commented Sep 9, 2025

Walkthrough

Adds mixed-precision per-weight quantization (INT4/INT8), centralizes block-aware quant math into shared utilities, extends Q/DQ and QDQ insertion to honor per-weight bit widths, adds CLI flags --enable_mixed_quant and --layers_8bit, and threads precision mappings through the quantize pipeline.

Changes

  • CLI: mixed-precision flags (examples/windows/onnx_ptq/genai_llm/quantize.py)
    Adds --enable_mixed_quant (flag, default False) and --layers_8bit (str, default ""); when --layers_8bit is provided, enable_mixed_quant is forced to true; prints both values at startup and forwards them to quantize_int4 (a minimal wiring sketch follows this list).
  • INT4 quant core: mixed-precision plumbing (modelopt/onnx/quantization/int4.py)
    Replaces many local quant helpers with imports from quant_utils and qdq_utils; threads precision_info through the RTN/AWQ/AWQ-lite flows; updates calls so quantize(...) / quantize_int4 accept and propagate enable_mixed_quant and layers_8bit; per-weight num_bits and the updated block_size are used for scales, quantization, and DQ insertion.
  • QDQ utilities: dtype & per-weight overrides (modelopt/onnx/quantization/qdq_utils.py)
    Adds INT4/UINT4 dtype mappings and get_tensor_dtype; insert_dq_nodes and insert_qdq_nodes accept precision_info and compute per-weight num_bits/dtypes and zero-point handling; adds per-channel attribute/validation helpers.
  • Shared quant math & helpers (modelopt/onnx/quantization/quant_utils.py)
    Adds block-aware quant utilities (CLIP_MIN, _pad/_depad, _next_block_size_multiple, update_block_size, get_num_bits, reshape_scales_for_per_channel_nodes, find_scales, rtn, dq_tensor, quant_tensor) supporting per-channel/per-block quantization, optional zero-point, and returning scales/zero-points.
  • Graph precision mapping (modelopt/onnx/quantization/graph_utils.py)
    Adds _find_int4_quantizable_weights, should_quantize_to_8bit, validate_8bit_layers, get_layer_precision_mapping, and get_precision_info to validate/parse layers_8bit and produce the per-layer precision mappings used when mixed quant is enabled.
  • Docs (examples/windows/onnx_ptq/genai_llm/README.md)
    Appends README rows documenting the --enable_mixed_quant and --layers_8bit flags (text added; original minor typo preserved).
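The CLI item above notes that supplying --layers_8bit implies mixed quantization. Below is a minimal wiring sketch assuming a standard argparse setup; the commented quantize_int4 call only indicates how the values are forwarded and is not the exact signature used in quantize.py.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable_mixed_quant", action="store_true", default=False)
parser.add_argument("--layers_8bit", type=str, default="")
args = parser.parse_args()

# Supplying --layers_8bit forces mixed quantization on.
if args.layers_8bit:
    args.enable_mixed_quant = True

# Both values are reported at startup and forwarded to the INT4 entry point,
# roughly: quantize_int4(..., enable_mixed_quant=args.enable_mixed_quant,
#                        layers_8bit=args.layers_8bit)
print(f"enable_mixed_quant={args.enable_mixed_quant}, layers_8bit={args.layers_8bit!r}")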

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant CLI as quantize.py (CLI)
  participant INT4 as modelopt.onnx.quantization.int4.quantize
  participant MAP as graph_utils.get_precision_info
  participant PIPE as RTN/AWQ/AWQ-lite
  participant QUTIL as quant_utils
  participant QDQ as qdq_utils

  CLI->>INT4: quantize(..., enable_mixed_quant, layers_8bit)
  alt layers_8bit provided
    INT4-->>INT4: force enable_mixed_quant = true
  end
  alt enable_mixed_quant == true
    INT4->>MAP: get_precision_info(model, precision_pattern_8bit=layers_8bit)
    MAP-->>INT4: precision_info (weight -> 4|8)
  else
    INT4-->>INT4: precision_info = None
  end

  INT4->>PIPE: run quant path(..., precision_info, enable_mixed_quant)
  PIPE->>QUTIL: find_scales / quant_tensor / rtn(..., num_bits via precision_info)
  QUTIL-->>PIPE: quantized weights, scales, zero-points

  PIPE->>QDQ: insert_dq_nodes / insert_qdq_nodes(..., precision_info)
  QDQ-->>PIPE: nodes with per-weight dtype (INT4/INT8) and zero-point handling

  PIPE-->>INT4: updated ONNX model
  INT4-->>CLI: write quantized model
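Concretely, the precision_info object threaded through this diagram is a plain mapping from weight initializer names to bit widths; the names below are illustrative only.

# Illustrative only: real keys are the model's weight initializer names.
precision_info = {
    "model.layers.0.attn.qkv_proj.MatMul.weight": 8,  # per-channel INT8
    "model.layers.1.attn.qkv_proj.MatMul.weight": 4,  # block-wise INT4
    "model.layers.1.mlp.down_proj.MatMul.weight": 4,
}
# Weights absent from the mapping fall back to the 4-bit default.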

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I nibble bits in fields of four,
Then hop to eight when layers implore.
I map each weight, choose num_bits with care,
Stitch scales and DQ nodes everywhere.
Hooray — mixed-precision hops through the air. 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed (check skipped: CodeRabbit’s high-level summary is enabled)
  • Title Check: ✅ Passed (the title concisely and accurately summarizes the primary change, adding mixed-precision (INT4+INT8) support to ModelOpt, and aligns with the PR objectives and the modified files)
  • Docstring Coverage: ✅ Passed (docstring coverage is 80.95%, which meets the required threshold of 80.00%)
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch ynankani/mixed_precision-int4-int8

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.



Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
modelopt/onnx/quantization/int4.py (6)

579-587: Fix: pass a GraphSurgeon graph to _quantize_gather_nodes in AWQ-clip.

_quantize_gather_nodes iterates graph.nodes (GraphSurgeon), but here graph is an ONNX GraphProto (augmented_model.graph). This will throw at runtime.

Apply:

-        gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
-            graph,
+        gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
+            graph_gs,
             nodes_to_exclude,
             gather_quantize_axis,
             gather_block_size,
             use_zero_point=False,
             dq_only=True,
         )

1327-1334: Per-channel attribute missing for Gather DQ nodes in AWQ-lite.

Mirror the per-channel handling for Gather weights too.

         qdq.insert_dq_nodes(
             graph_gs,
             gather_s_map,
             quantized_weights=gather_w_map,
             attributes=gather_dq_node_attributes,
             zero_points=gather_zp_map if use_zero_point else None,
-            precision_info=precision_info,
+            precision_info=precision_info,
+            is_per_channel=is_per_channel,
         )

159-166: Avoid mutable default list for nodes_to_exclude in RTN API.

-def quantize_rtn(
+def quantize_rtn(
     onnx_model: onnx.ModelProto,
     block_size: int,
     dq_only: bool = False,
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
     precision_info: dict[str, int] | None = None,
     **kwargs: Any,
 ) -> onnx.ModelProto:
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])

457-468: Avoid mutable default list for nodes_to_exclude in AWQ-clip API.

-def _quantize_awq_clip(
+def _quantize_awq_clip(
     onnx_model: onnx.ModelProto,
@@
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])

1000-1014: Avoid mutable default list for nodes_to_exclude in AWQ-lite API.

-def _quantize_awq_lite(
+def _quantize_awq_lite(
@@
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])

1470-1484: Avoid mutable default list for nodes_to_exclude; add params doc for mixed precision.

-def quantize(
+def quantize(
@@
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
@@
-    logger.debug(f"Excluding nodes matching patterns: {nodes_to_exclude}")
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    logger.debug(f"Excluding nodes matching patterns: {nodes_to_exclude}")

Optional: Update the docstring to mention k_quant_mixed and int8_layers behavior.
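For readers unfamiliar with the pitfall behind these suggestions: a mutable default argument is created once at function definition time and shared across calls. A small standalone demonstration (not ModelOpt code):

def buggy(nodes_to_exclude: list[str] = []):
    # The same list object is reused on every call.
    nodes_to_exclude.append(r"/lm_head")
    return nodes_to_exclude

print(buggy())  # ['/lm_head']
print(buggy())  # ['/lm_head', '/lm_head']  <- state leaks across calls

def fixed(nodes_to_exclude: list[str] | None = None):
    # Fresh default on every call, as the suggested diffs do.
    return list(nodes_to_exclude or [r"/lm_head"])

print(fixed())  # ['/lm_head']
print(fixed())  # ['/lm_head']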

🧹 Nitpick comments (8)
modelopt/onnx/quantization/quant_utils.py (3)

169-171: Consider inlining this simple calculation.

This helper function is only used once (in _pad) and performs a trivial calculation. Consider inlining it directly where used to reduce unnecessary abstraction.

-def _next_block_size_multiple(x: float, block_size: int) -> float:
-    return math.ceil(x / block_size) * block_size
-

 def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray:
     """Pads `w` to next largest multiple of block_size, on quantize_axis."""
     assert quantize_axis <= len(w.shape), (
         f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
     )
     
     if w.shape[quantize_axis] % block_size == 0:
         return w
     
-    pad_width = (
-        _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis]
-    )
+    pad_width = math.ceil(w.shape[quantize_axis] / block_size) * block_size - w.shape[quantize_axis]

190-201: Consider using numpy slicing syntax for all axes.

The function handles only 2D arrays but could be made more general and consistent using numpy's ellipsis syntax.

 def _depad(w: np.ndarray, orig_shape: tuple, quantize_axis: int = 0) -> np.ndarray:
     """Depad quantize_axis to original shape."""
     if w.shape == orig_shape:
         return w
-    ans = None
-    if quantize_axis == 0:
-        ans = w[0 : orig_shape[0], ...]
-    elif quantize_axis == 1:
-        ans = w[..., 0 : orig_shape[1]]
-    else:
-        raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array")
-    return ans
+    
+    if quantize_axis not in [0, 1]:
+        raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array")
+    
+    slices = [slice(None)] * len(w.shape)
+    slices[quantize_axis] = slice(0, orig_shape[quantize_axis])
+    return w[tuple(slices)]

323-338: Add docstring details about return values.

The function's docstring should specify what the three return values represent.

 def quant_tensor(
     w: np.ndarray,
     block_size: int,
     quantize_axis: int = 0,
     alpha: float = 1.0,
     use_zero_point: bool = False,
     precision_info: dict[str, int] | None = None,
     name: str | None = None,
 ):
-    """Quantize a tensor using alpha etc. and return the quantized tensor."""
+    """Quantize a tensor using alpha etc. and return the quantized tensor.
+    
+    Returns:
+        tuple: A tuple containing:
+            - wq: The quantized weight tensor (np.ndarray)
+            - scale: The scale factors used for quantization (np.ndarray)
+            - zp: The zero-point values (np.ndarray or None if not using zero-point)
+    """
     scale, zp = find_scales(
         w, block_size, quantize_axis, alpha, use_zero_point, precision_info, name
     )
     wq = rtn(w, scale, block_size, quantize_axis, zp, precision_info, name)
     return wq, scale, zp
examples/windows/onnx_ptq/genai_llm/quantize.py (1)

599-613: Consider validating the int8_layers pattern format.

The int8_layers parameter accepts a comma-separated string of layer patterns, but there's no validation of the format. Consider adding basic validation to catch user errors early.

 parser.add_argument(
     "--int8_layers",
     type=str,
     default="",
     help=(
         "Comma-separated list of layer patterns to quantize to INT8 instead of INT4."
         "Example: 'layers.0,layers.1,lm_head'"
     ),
 )
+
+def validate_int8_layers(layers_str: str) -> bool:
+    """Validate the format of int8_layers string."""
+    if not layers_str:
+        return True
+    # Basic validation: check for valid characters and structure
+    import re
+    pattern = r'^[a-zA-Z0-9_.,\-]+$'
+    return bool(re.match(pattern, layers_str))

Then add validation after parsing:

if args.int8_layers and not validate_int8_layers(args.int8_layers):
    parser.error("Invalid format for --int8_layers. Use comma-separated layer patterns.")
modelopt/onnx/quantization/qdq_utils.py (3)

310-322: Add input validation for the attributes parameter.

The function modifies the attributes dictionary in place, which could cause unexpected side effects if the same dictionary is reused.

 def update_attributes(attrib: dict[str, Any] | None = None, is_per_channel: bool = False):
     """Update attribute dictionary for quantization nodes.
     
     If per-channel quantization is enabled, sets the 'axis' attribute to 1 and removes
     the 'block_size' attribute if present.
+    
+    Args:
+        attrib: Attribute dictionary to update (will be modified in place if not None)
+        is_per_channel: Whether to enable per-channel quantization
+    
+    Returns:
+        The updated attribute dictionary
     """
     if is_per_channel:
         if attrib is not None:
             attrib["axis"] = 1
             if "block_size" in attrib:
                 attrib.pop("block_size")
     return attrib

354-369: Consider extracting tensor dtype selection logic to a helper function.

The tensor dtype selection logic is duplicated in both insert_dq_nodes and insert_qdq_nodes. Consider extracting it to reduce duplication and improve maintainability.

+def get_tensor_dtype(precision_info: dict[str, int] | None, name: str, has_zero_point: bool) -> int:
+    """Get the appropriate tensor dtype based on precision info and zero point presence.
+    
+    Args:
+        precision_info: Dictionary mapping tensor names to bit widths
+        name: Name of the tensor
+        has_zero_point: Whether the tensor has a zero point
+    
+    Returns:
+        ONNX tensor data type constant
+    """
+    if precision_info and name in precision_info:
+        bit_width = precision_info[name]
+        if has_zero_point:
+            dtype_str = onnx_bit_dtype_unsigned_map[bit_width]
+        else:
+            dtype_str = onnx_bit_dtype_signed_map[bit_width]
+        return onnx_dtype_map[dtype_str]
+    else:
+        # Default to 4-bit
+        return onnx.TensorProto.INT4 if not has_zero_point else onnx.TensorProto.UINT4

 def _insert_helper(
     name: str,
     wq: np.ndarray,
     scale: np.ndarray,
     dq_nodes: dict[str, gs.Node],
     zp: np.ndarray,
     attrs: dict[str, Any] | None = None,
     precision_info: dict[str, int] | None = None,
     is_per_channel: bool = False,
 ):
     attrib = dict(attrs) if attrs is not None else None
-    if precision_info and name in precision_info:
-        tensor_dtype = (
-            onnx_dtype_map[onnx_bit_dtype_signed_map[precision_info[name]]]
-            if zp is None
-            else onnx_dtype_map[onnx_bit_dtype_unsigned_map[precision_info[name]]]
-        )
-        # do per-channel quantization for int8 as no support for int8 block-wise dq node
-        if precision_info[name] == 8:
-            # reshape scale to be per-channel
-            scale = scale.reshape(-1)
-            attrib = update_attributes(attrib, True)
-    else:
-        tensor_dtype = onnx.TensorProto.INT4 if zp is None else onnx.TensorProto.UINT4
+    tensor_dtype = get_tensor_dtype(precision_info, name, zp is not None)
+    
+    # do per-channel quantization for int8 as no support for int8 block-wise dq node
+    if precision_info and name in precision_info and precision_info[name] == 8:
+        # reshape scale to be per-channel
+        scale = scale.reshape(-1)
+        attrib = update_attributes(attrib, True)

361-364: Validate per‐channel scale size before flattening. Although scale.reshape(-1) preserves all elements, add an explicit check (e.g. assert scale.size == wq.shape[1]) before calling reshape to guard against mismatches when quantizing along axis 1.
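A minimal, self-contained form of that guard, assuming a per-channel layout with one scale per output column of wq (shapes are made up):

import numpy as np

wq = np.zeros((128, 16), dtype=np.int8)      # hypothetical quantized weight
scale = np.ones((1, 16), dtype=np.float32)   # per-channel scales, one per column

assert scale.size == wq.shape[1], (
    f"per-channel scale count {scale.size} != output channels {wq.shape[1]}"
)
scale = scale.reshape(-1)  # flatten to the 1-D form the DQ node expects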

modelopt/onnx/quantization/int4.py (1)

620-624: Use logger.warning instead of deprecated logger.warn.

-        logger.warn("Augmented ONNX model or external data file was not found")
+        logger.warning("Augmented ONNX model or external data file was not found")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d6d2e75 and caa9965.

📒 Files selected for processing (4)
  • examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
  • modelopt/onnx/quantization/int4.py (29 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (8 hunks)
  • modelopt/onnx/quantization/quant_utils.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/quant_utils.py (5)
  • _pad (173-187)
  • dq_tensor (297-320)
  • find_scales (204-252)
  • quant_tensor (323-337)
  • rtn (255-294)
modelopt/onnx/quantization/qdq_utils.py (2)
  • insert_dq_nodes (324-406)
  • insert_qdq_nodes (409-463)
modelopt/onnx/quantization/graph_utils.py (1)
  • expand_node_names_from_patterns (628-639)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • num_bits (180-182)
  • num_bits (185-187)
  • axis (279-281)
  • axis (284-286)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (10)
modelopt/onnx/quantization/quant_utils.py (3)

160-167: LGTM! Clean implementation of bit-width selection logic.

The function provides a clear and simple way to get the number of bits for quantization based on the precision info dictionary, with a sensible default of 4 bits.


173-188: LGTM! Proper padding implementation with clear assertions.

The padding logic correctly handles different axes and includes appropriate assertions for validation.


297-321: LGTM! Clear dequantization implementation.

The dequantization logic correctly handles both signed and unsigned cases with proper padding and depadding.

examples/windows/onnx_ptq/genai_llm/quantize.py (2)

368-369: LGTM! Status output correctly includes the new k_quant_mixed flag.

The logging statement properly shows the new mixed-precision quantization option.


438-439: LGTM! Proper propagation of mixed-precision parameters.

The new parameters are correctly passed to the quantize_int4 function.

modelopt/onnx/quantization/qdq_utils.py (2)

57-61: LGTM! Proper extension of dtype mappings for 4-bit support.

The addition of INT4 and UINT4 to the dtype maps enables proper 4-bit quantization support.


432-440: LGTM! Consistent handling of per-weight precision in QDQ nodes.

The implementation correctly mirrors the DQ node logic for handling per-weight precision information.

modelopt/onnx/quantization/int4.py (3)

206-209: LGTM: precision_info is correctly threaded in RTN path.

Scale computation, quantization, and Q/DQ insertion honor per-weight precision and per-channel when applicable.

Also applies to: 240-243, 248-255, 259-265


315-322: LGTM: precision_info propagation in AWQ-clip compute path.

The clip search and final quantization use per-weight precision consistently.

Also applies to: 362-366, 526-528, 552-555


705-707: LGTM: precision_info propagation in AWQ-lite compute path.

Both per-node and per-subgraph scale searches respect mixed precision; quantization/dequantization calls pass the context correctly.

Also applies to: 762-771, 782-783, 882-884, 939-952, 1112-1131, 1182-1188

Comment on lines 1311 to 1283
dq_node_attributes = {"axis": 0, "block_size": block_size}

qdq.insert_dq_nodes(
graph_gs,
scales,
quantized_weights=gemm_weights_quantized,
attributes=dq_node_attributes,
zero_points=zero_points if use_zero_point else None,
precision_info=precision_info,
)

🛠️ Refactor suggestion

Per-channel attribute not propagated in AWQ-lite.

When block_size == -1, you set per-channel in RTN and AWQ-clip, but not here. TRT plugins require this attribute for correct dequantization.

Apply:

     dq_node_attributes = {"axis": 0, "block_size": block_size}
+    is_per_channel = block_size == -1

     qdq.insert_dq_nodes(
         graph_gs,
         scales,
         quantized_weights=gemm_weights_quantized,
         attributes=dq_node_attributes,
         zero_points=zero_points if use_zero_point else None,
-        precision_info=precision_info,
+        precision_info=precision_info,
+        is_per_channel=is_per_channel,
     )
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/int4.py around lines 1311-1320, the
dq_node_attributes only contains "axis" and "block_size" for AWQ-lite; when
block_size == -1 you must also set the per-channel flag (as done for RTN and
AWQ-clip) so TRT plugins get correct dequantization info. Modify the
construction of dq_node_attributes to include "per_channel": True when
block_size == -1 (otherwise keep it False or omit), then pass that attributes
dict into qdq.insert_dq_nodes so the per-channel attribute is propagated.

Comment on lines 1357 to 1377
def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
"""Check if layer should be quantized to INT8.
The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'.
To match these, we:
- Remove the leading slash from the node name.
- Replace all '/' with '.' to match the naming convention of the initializer.
This allows us to correctly identify which weights should be quantized to INT8.
"""
if not int8_layers:
return False
normalized_patterns = []
for pattern in int8_layers:
p = pattern.lstrip("/")
p = p.replace("/", ".")
normalized_patterns.append(p)
return any(norm_pattern in layer_name for norm_pattern in normalized_patterns)


🛠️ Refactor suggestion

Avoid substring false-positives when matching INT8 layers.

in substring matching can misclassify e.g., layers.1 matches layers.11. Use token-exact matching.

-def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
+def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool:
@@
-    normalized_patterns = []
-    for pattern in int8_layers:
-        p = pattern.lstrip("/")
-        p = p.replace("/", ".")
-        normalized_patterns.append(p)
-    return any(norm_pattern in layer_name for norm_pattern in normalized_patterns)
+    # Normalize both to dot-delimited tokens and require exact token sequence match.
+    def tokens(s: str) -> list[str]:
+        return s.lstrip("/").replace("/", ".").split(".")
+    hay = tokens(layer_name)
+    for pat in int8_layers:
+        needle = tokens(pat)
+        n, m = len(hay), len(needle)
+        for i in range(0, n - m + 1):
+            if hay[i : i + m] == needle:
+                return True
+    return False
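A standalone illustration of the false positive described above, using the same token-normalization idea as the suggested matcher (an approximation, not the repository code):

def _tokens(s: str) -> list[str]:
    return s.lstrip("/").replace("/", ".").split(".")

def token_match(layer_name: str, pattern: str) -> bool:
    hay, needle = _tokens(layer_name), _tokens(pattern)
    return any(hay[i:i + len(needle)] == needle
               for i in range(len(hay) - len(needle) + 1))

name = "model.layers.11.attn.qkv_proj.MatMul.weight"
print("layers.1" in name)              # True: substring match misclassifies layers.11
print(token_match(name, "layers.1"))   # False: token-exact match does not
print(token_match(name, "layers.11"))  # True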

Comment on lines 1379 to 1387
def get_layer_precision_mapping(
onnx_model: onnx.ModelProto,
int8_precision_pattern: str | None = None,
nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
graph = onnx_model.graph

nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
# Collect quantizable weight tensors

🛠️ Refactor suggestion

Don’t use a mutable list as a default argument.

nodes_to_exclude: list[str] | None = [r"/lm_head"] risks cross-call mutation.

-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
     onnx_model: onnx.ModelProto,
     int8_precision_pattern: str | None = None,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
 ):
-    graph = onnx_model.graph
+    graph = onnx_model.graph
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]

Optional: allow regex patterns in int8_precision_pattern to be expanded via expand_node_names_from_patterns for robustness.

🤖 Prompt for AI Agents
In modelopt/onnx/quantization/int4.py around lines 1379 to 1387, the function
signature uses a mutable default list for nodes_to_exclude which can lead to
cross-call mutation; change the signature to use None as the default
(nodes_to_exclude: list[str] | None = None) and inside the function set
nodes_to_exclude = [r"/lm_head"] if nodes_to_exclude is None before calling
expand_node_names_from_patterns; additionally (optional) if
int8_precision_pattern should accept regexes, pass it through
expand_node_names_from_patterns or a similar helper so patterns are expanded
consistently before use.

Comment on lines 255 to 326
def rtn(
w: np.ndarray,
s: np.ndarray,
block_size: int,
quantize_axis: int = 0,
zp: np.ndarray = None,
precision_info: dict[str, int] | None = None,
name: str | None = None,
) -> np.ndarray:
"""Quantizes `w` with scale factors `s` via Round-to-Nearest.
Ties are broken by rounding to the nearest even number.
"""
num_bits = get_num_bits(precision_info, name)
# If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node,
# set block_size to the size of the quantize_axis dimension to do per-channel quantization
if block_size == -1 or num_bits == 8:
block_size = w.shape[quantize_axis]
w_padded = _pad(w, block_size, quantize_axis)
num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis]
if zp is None:
maxq = 2 ** (num_bits - 1) - 1
minq = -(2 ** (num_bits - 1))
w_padded = (
np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
.clip(minq, maxq)
.astype(np.int8)
)
else:
maxq = (2**num_bits) - 1
minq = 0
w_padded = (
(
np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
+ zp.repeat(num_blocks, axis=quantize_axis)
)
.clip(minq, maxq)
.astype(np.int8)
)
return _depad(w_padded, w.shape, quantize_axis)


🛠️ Refactor suggestion

Improve numerical stability for edge cases.

When using zero-point quantization, division by s could cause issues if s contains values close to zero (even though clipped by CLIP_MIN). Consider adding a safety check.

 def rtn(
     w: np.ndarray,
     s: np.ndarray,
     block_size: int,
     quantize_axis: int = 0,
     zp: np.ndarray = None,
     precision_info: dict[str, int] | None = None,
     name: str | None = None,
 ) -> np.ndarray:
     """Quantizes `w` with scale factors `s` via Round-to-Nearest.
     
     Ties are broken by rounding to the nearest even number.
     """
     num_bits = get_num_bits(precision_info, name)
     # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node,
     # set block_size to the size of the quantize_axis dimension to do per-channel quantization
     if block_size == -1 or num_bits == 8:
         block_size = w.shape[quantize_axis]
     w_padded = _pad(w, block_size, quantize_axis)
     num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis]
+    
+    # Ensure scale values are not too small to avoid numerical instability
+    if np.any(np.abs(s) < CLIP_MIN):
+        raise ValueError(f"Scale values too small for stable quantization: min={np.min(np.abs(s))}")
+    
     if zp is None:
         maxq = 2 ** (num_bits - 1) - 1
         minq = -(2 ** (num_bits - 1))
         w_padded = (
             np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
             .clip(minq, maxq)
             .astype(np.int8)
         )
     else:
         maxq = (2**num_bits) - 1
         minq = 0
         w_padded = (
             (
                 np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
                 + zp.repeat(num_blocks, axis=quantize_axis)
             )
             .clip(minq, maxq)
             .astype(np.int8)
         )
     return _depad(w_padded, w.shape, quantize_axis)

)

parser.add_argument(
"--k_quant_mixed",
Contributor

What does k stand for in k_quant_mixed?

@ynankani (Contributor, Author) Sep 10, 2025

k stands for group-wise/block-wise quantization. We are doing INT4 block-wise quantization plus a few layers in INT8 per-channel quantization (currently we don't have INT8 block-wise DequantizeLinear node support on the TRT-RTX side).
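To make the shape difference concrete: block-wise INT4 keeps one scale per (block, column), while the INT8 fallback is per-channel, one scale per output channel, flattened because there is no INT8 block-wise DequantizeLinear node. The dimensions below are made up for illustration.

K, N, block = 4096, 11008, 128      # hypothetical weight shape and block size

# INT4 block-wise (quantize axis 0): one scale per 128-row block and per column.
int4_scale_shape = (K // block, N)   # (32, 11008)

# INT8 per-channel: the block spans the whole axis, leaving one scale per column.
int8_scale_shape = (N,)              # (11008,)

print(int4_scale_shape, int8_scale_shape)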

@ynankani (Contributor, Author)

Need to make a few more changes as per the discussion with Vishal. Will request review once the updated changes are ready.

@ynankani ynankani force-pushed the ynankani/mixed_precision-int4-int8 branch from caa9965 to b6a39be Compare September 12, 2025 18:22

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)

417-450: Zero-point dtype should be unsigned.

When creating Q/DQ, zp should use the unsigned dtype. Pass has_zero_point=True to get_tensor_dtype.

-        tensor_dtype = get_tensor_dtype(num_bits)
+        tensor_dtype = get_tensor_dtype(num_bits)
@@
-        zp_tensor = make_gs_zp(name, scale.shape, tensor_dtype)
+        zp_tensor = make_gs_zp(name, scale.shape, get_tensor_dtype(num_bits, has_zero_point=True))
♻️ Duplicate comments (1)
modelopt/onnx/quantization/quant_utils.py (1)

292-326: Numerical safety in rtn.

Add a guard to avoid tiny s after broadcast.

-    if zp is None:
+    # Safety: avoid division by tiny scales
+    if np.any(np.abs(s) < CLIP_MIN):
+        s = np.where(np.abs(s) < CLIP_MIN, CLIP_MIN, s)
+    if zp is None:
         maxq = 2 ** (num_bits - 1) - 1
🧹 Nitpick comments (7)
modelopt/onnx/quantization/graph_utils.py (2)

629-667: Type hint misuse and minor robustness.

  • cast("int", weight_tensor.data_type) uses a string literal as a type param; use cast(int, ...) or drop cast.
  • Consider guarding for Gemm transA (currently TODO) or explicitly documenting unsupported cases.
-        gemm_io_type = cast("int", weight_tensor.data_type)
+        gemm_io_type = cast(int, weight_tensor.data_type)

669-696: Token-exact matcher looks good. Add fast exit and docstring type.

Add -> bool and early return for long patterns.

-def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
+def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool:
@@
-    for pat in int8_layers:
+    for pat in int8_layers:
         needle = tokens(pat)
         n, m = len(hay), len(needle)
+        if m == 0 or m > n:
+            continue
modelopt/onnx/quantization/quant_utils.py (2)

160-177: update_block_size: clarify contract for w=None.

Function dereferences w unconditionally when block_size==-1 or num_bits==8. Either assert or handle None.

-    if block_size is not None and (block_size == -1 or num_bits == 8):
-        return w.shape[quantize_axis]
+    if block_size is not None and (block_size == -1 or num_bits == 8):
+        assert w is not None, "update_block_size requires `w` when per-channel/INT8 is requested"
+        return w.shape[quantize_axis]

196-198: Type hint for _next_block_size_multiple.

Return type is int in practice; adjust for readability/tools.

-def _next_block_size_multiple(x: float, block_size: int) -> float:
+def _next_block_size_multiple(x: float, block_size: int) -> int:
modelopt/onnx/quantization/qdq_utils.py (1)

50-63: ONNX 4-bit enums: confirm availability.

onnx.TensorProto.INT4/UINT4 are relatively new; some environments may lack them. Add a fallback or guard.

Option: probe availability once and raise a clear error.

-onnx_dtype_map = {
+onnx_dtype_map = {
@@
-    "INT4": onnx.TensorProto.INT4,
-    "UINT4": onnx.TensorProto.UINT4,
+    "INT4": getattr(onnx.TensorProto, "INT4", None),
+    "UINT4": getattr(onnx.TensorProto, "UINT4", None),
 }
+
+# Validate availability early
+if onnx_dtype_map["INT4"] is None or onnx_dtype_map["UINT4"] is None:
+    logger.error("This ONNX build does not support 4-bit dtypes (INT4/UINT4).")
+    # Optionally: raise, or fall back to INT8 path depending on product decision

Would you like me to wire a clean fallback to INT8 when 4-bit enums are missing?

modelopt/onnx/quantization/int4.py (2)

221-227: Pattern-vs-name expansion mismatch risk.

nodes_to_exclude is already expanded to concrete names before calling get_precision_info, which expects patterns (it internally expands again). This is benign but redundant and can mis-handle special regex chars in names.

Two options:

  • Pass the original patterns into get_precision_info, or
  • Change get_precision_info to accept already-expanded names and skip expansion.

1369-1376: Doc/type mismatch for int8_layers.

Docstring says “comma-separated list” but default shows [] and the code expects a string or None. Clarify.

-                - **int8_layers** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
-                                              Default: [].
+                - **int8_layers** (str | None): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
+                                              Default: None.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between caa9965 and b6a39be.

📒 Files selected for processing (6)
  • examples/windows/onnx_ptq/genai_llm/README.md (1 hunks)
  • examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
  • modelopt/onnx/quantization/graph_utils.py (2 hunks)
  • modelopt/onnx/quantization/int4.py (27 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (8 hunks)
  • modelopt/onnx/quantization/quant_utils.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
  • examples/windows/onnx_ptq/genai_llm/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/windows/onnx_ptq/genai_llm/quantize.py
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/onnx/quantization/quant_utils.py (1)
  • get_num_bits (179-193)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/graph_utils.py (2)
  • _find_quantizable_weights (629-666)
  • get_precision_info (814-843)
modelopt/onnx/quantization/quant_utils.py (8)
  • _pad (200-214)
  • dq_tensor (328-344)
  • find_scales (242-289)
  • get_num_bits (179-193)
  • quant_tensor (347-366)
  • rtn (292-325)
  • update_block_size (160-176)
  • update_scale_map_for_per_channel_nodes (231-239)
modelopt/onnx/quantization/qdq_utils.py (2)
  • insert_dq_nodes (327-399)
  • insert_qdq_nodes (402-456)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • num_bits (180-182)
  • num_bits (185-187)
  • axis (279-281)
  • axis (284-286)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/onnx/quantization/int4.py (2)

562-571: Per-channel handling is correctly propagated.

Good use of update_scale_map_for_per_channel_nodes and attribute adjustment in qdq insertion to set axis=1 and drop block_size for INT8/per-channel. This aligns TRT plugin expectations.

Also applies to: 1264-1273


244-253: Gather path: mixed-precision propagation looks correct.

You compute per-weight num_bits and reuse update_block_size for gather. Nice reuse and consistency across RTN/AWQ paths.

Also applies to: 1253-1261

Comment on lines 709 to 714
def get_layer_precision_mapping(
onnx_model: onnx.ModelProto,
int8_precision_pattern: str | None = None,
nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
"""Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model.

🛠️ Refactor suggestion

Avoid mutable default args and normalize exclusions.

Using a list default (nodes_to_exclude=[r"/lm_head"]) risks cross-call mutation. Also normalize nodes_to_exclude up-front.

Apply:

-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
     onnx_model: onnx.ModelProto,
     int8_precision_pattern: str | None = None,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
 ):
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)

Also applies to: 728-731

🤖 Prompt for AI Agents
In modelopt/onnx/quantization/graph_utils.py around lines 709-714 (and similarly
at 728-731), avoid using a mutable list as a default for nodes_to_exclude:
change the signature to use nodes_to_exclude: list[str] | None = None, then
inside the function set nodes_to_exclude = [r"/lm_head"] if nodes_to_exclude is
None else list(nodes_to_exclude) to prevent shared-state mutations; immediately
normalize entries (e.g., strip/ensure raw strings or compiled regex as the code
expects) so downstream logic can assume a consistent list type and format.

Comment on lines 814 to 842
def get_precision_info(
onnx_model: onnx.ModelProto,
nodes_to_exclude: list[str] | None = [r"/lm_head"],
**kwargs: Any,
):
"""Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).
This function determines the quantization precision for each weight tensor in the ONNX model,
based on the provided configuration. If mixed quantization is enabled, it uses the layer
precision mapping; otherwise, it returns None.
Args:
onnx_model (onnx.ModelProto): The ONNX model to analyze.
nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
**kwargs: Additional keyword arguments, such as:
- enable_mixed_quant (bool): Whether to enable mixed quantization.
- int8_layers (str): Comma-separated list of layer patterns to quantize to INT8.
Returns:
dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
or None if mixed quantization is not enabled.
"""
precision_info = None
enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
int8_layers = kwargs.get("int8_layers")
if enable_mixed_quant:
precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude)
else:
precision_info = None
return precision_info


🛠️ Refactor suggestion

Same mutable default issue in public API.

Apply same fix as above.

-def get_precision_info(
+def get_precision_info(
     onnx_model: onnx.ModelProto,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
     **kwargs: Any,
 ):
@@
-    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In modelopt/onnx/quantization/graph_utils.py around lines 814 to 844, the
function get_precision_info declares a mutable default parameter
nodes_to_exclude = [r"/lm_head"], which should be avoided; change the signature
to use nodes_to_exclude: list[str] | None = None (or Optional[list[str]] = None)
and inside the function set nodes_to_exclude = [r"/lm_head"] if nodes_to_exclude
is None, preserving the original default behavior while avoiding the mutable
default; update any type hints and callers if necessary.

Comment on lines +311 to +325
def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
"""Get the appropriate tensor dtype based on precision info and zero point presence.
Args:
num_bits: Number of bits for quantization
has_zero_point: Whether the tensor has a zero point
Returns:
ONNX tensor data type constant
"""
if has_zero_point:
dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
else:
dtype_str = onnx_bit_dtype_signed_map[num_bits]
return onnx_dtype_map[dtype_str]


🛠️ Refactor suggestion

Validate num_bits and map lookups.

Guard against unexpected num_bits and derive dtype accordingly.

 def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
@@
-    if has_zero_point:
-        dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
-    else:
-        dtype_str = onnx_bit_dtype_signed_map[num_bits]
+    valid_bits = onnx_bit_dtype_unsigned_map.keys() & onnx_bit_dtype_signed_map.keys()
+    if num_bits not in valid_bits:
+        raise ValueError(f"Unsupported num_bits={num_bits}; expected one of {sorted(valid_bits)}")
+    dtype_str = (
+        onnx_bit_dtype_unsigned_map[num_bits]
+        if has_zero_point
+        else onnx_bit_dtype_signed_map[num_bits]
+    )
     return onnx_dtype_map[dtype_str]
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/qdq_utils.py around lines 311 to 325, the function
does unguarded lookups on num_bits and mapping dicts which can raise KeyError
for unexpected values; validate that num_bits is one of the supported keys
(e.g., check presence in onnx_bit_dtype_unsigned_map and
onnx_bit_dtype_signed_map) before using it, raise a clear ValueError if
unsupported, and guard the final onnx_dtype_map lookup similarly (or provide a
default/raise with context) so the function always fails fast with a clear
message instead of crashing with a KeyError.


codecov bot commented Sep 12, 2025

Codecov Report

❌ Patch coverage is 57.36434% with 110 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.48%. Comparing base (3d2004b) to head (6201e7b).

Files with missing lines (patch coverage, missing lines):
  • modelopt/onnx/quantization/graph_utils.py: 27.08%, 70 lines missing ⚠️
  • modelopt/onnx/quantization/quant_utils.py: 70.32%, 27 lines missing ⚠️
  • modelopt/onnx/quantization/int4.py: 83.33%, 7 lines missing ⚠️
  • modelopt/onnx/quantization/qdq_utils.py: 79.31%, 6 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #310      +/-   ##
==========================================
- Coverage   73.72%   73.48%   -0.24%     
==========================================
  Files         172      172              
  Lines       17484    17636     +152     
==========================================
+ Hits        12890    12960      +70     
- Misses       4594     4676      +82     

☔ View full report in Codecov by Sentry.

@ynankani ynankani requested a review from i-riyad September 16, 2025 03:59
Comment on lines 600 to 606
"--enable_mixed_quant",
default=False,
action="store_true",
help="True when we want to use mixed quantization",
)
parser.add_argument(
"--int8_layers",
Contributor

Can we just assume mixed_quant is enabled if --int8_layers is non-empty, and remove the --enable_mixed_quant option?

@ynankani (Contributor, Author) Sep 18, 2025

If --int8_layers is specified, only the layers matching the patterns in int8_layers are quantized to INT8; if only --enable_mixed_quant is specified, we hard-code a selection of a few important layers, similar to what other quantization tools such as model builder/llama.cpp do.
Example:
"python quantize.py ... --int8_layers="layers.0" --enable_mixed_quant" => all layers.0 projections are quantized to INT8
"python quantize.py ... --enable_mixed_quant" => quantize to INT8 the first 1/8, last 1/8, and every 3rd layer of the attn and ffn projections below (a sketch of this heuristic follows):
/model/layers.{i}/attn/qkv_proj/MatMul
/model/layers.{i}/attn/v_proj/MatMul
/model/layers.{i}/mlp/down_proj/MatMul
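A small sketch of one plausible reading of that default heuristic (first 1/8, last 1/8, and every 3rd layer in between); the exact indexing in get_layer_precision_mapping may differ.

def default_8bit_layer_indices(num_layers: int) -> list[int]:
    """Hypothetical selection: first 1/8, last 1/8, every 3rd layer in between."""
    eighth = max(1, num_layers // 8)
    picked = set(range(eighth))                             # first 1/8
    picked |= set(range(num_layers - eighth, num_layers))   # last 1/8
    picked |= set(range(eighth, num_layers - eighth, 3))    # every 3rd in between
    return sorted(picked)

# Example for a 28-layer model: these layers' attn/mlp MatMuls go to INT8.
for i in default_8bit_layer_indices(28):
    for proj in ("attn/qkv_proj", "attn/v_proj", "mlp/down_proj"):
        print(f"/model/layers.{i}/{proj}/MatMul")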

Contributor

Gotcha, makes sense now.

@ynankani ynankani requested a review from i-riyad September 18, 2025 03:49
default="",
help=(
"Comma-separated list of layer patterns to quantize to INT8 instead of INT4."
"Example: 'layers.0,layers.1,lm_head'"
Contributor

nit: Overrides default mixed quant strategy. Example: 'layers.0,lm_head'

"--enable_mixed_quant",
default=False,
action="store_true",
help="True when we want to use mixed quantization",
@i-riyad (Contributor) Sep 19, 2025

nit: help="Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd layer quantized to INT8; others to INT4.",

…+INT8) ONNX models, refactored changes and handle comments

Signed-off-by: unknown <[email protected]>
…+INT8) ONNX models, handle comments and rename functions,variables

Signed-off-by: unknown <[email protected]>
@ynankani ynankani force-pushed the ynankani/mixed_precision-int4-int8 branch from b6a39be to b069cec on September 21, 2025 09:18
@ynankani ynankani requested a review from i-riyad September 21, 2025 09:20

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (12)
examples/windows/onnx_ptq/genai_llm/quantize.py (2)

364-366: Harden the gate that flips mixed-quant on when --layers_8bit is provided.

Trim whitespace and only flip when non-empty; also normalize the string once.

-    if args.layers_8bit:
-        args.enable_mixed_quant = True
+    if args.layers_8bit and args.layers_8bit.strip():
+        args.layers_8bit = args.layers_8bit.strip()
+        args.enable_mixed_quant = True

603-617: Tighten CLI help and document implied behavior.

Clarify that providing --layers_8bit implies mixed quant, and fix example punctuation.

-    parser.add_argument(
-        "--enable_mixed_quant",
+    parser.add_argument(
+        "--enable_mixed_quant",
         default=False,
         action="store_true",
         help=(
-            "Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd attn, "
-            "mlp layers quantized to 8 bits; others to 4 bits."
+            "Enable mixed quantization. Default strategy: first 1/8, last 1/8, and every 3rd "
+            "attn/mlp layer quantized to INT8; others to INT4."
         ),
     )
-    parser.add_argument(
+    parser.add_argument(
         "--layers_8bit",
         type=str,
         default="",
-        help=("Overrides default mixed quant strategy. Example: 'layers.0,lm_head'"),
+        help=("Overrides default strategy. Comma-separated patterns; providing this implies "
+              "--enable_mixed_quant. Example: 'layers.0,lm_head'."),
     )
modelopt/onnx/quantization/int4.py (4)

180-186: Avoid mutable default argument (nodes_to_exclude).

Using [] as a default risks cross-call mutation.

-def quantize_rtn(
+def quantize_rtn(
     onnx_model: onnx.ModelProto,
     block_size: int,
     dq_only: bool = False,
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
     **kwargs: Any,
 ) -> onnx.ModelProto:
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = nodes_to_exclude or []
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
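
As a side note, the hazard behind this suggestion is easy to reproduce in isolation (toy functions, not code from this PR):

def collect(item, acc=[]):      # the default list is created once, at function definition time
    acc.append(item)
    return acc

print(collect("a"))   # ['a']
print(collect("b"))   # ['a', 'b']  <- state leaked from the previous call

def collect_fixed(item, acc=None):
    acc = acc or []   # fresh list per call when the caller passes nothing
    acc.append(item)
    return acc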

439-448: Same mutable default issue in _quantize_awq_clip.

-def _quantize_awq_clip(
+def _quantize_awq_clip(
     onnx_model: onnx.ModelProto,
     data_reader: CalibrationDataReader,
     use_external_data_format: bool,
     calibration_eps: list[str],
     block_size: int,
     force_fp16: bool = False,
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
     input_shapes_profile: Sequence[dict[str, str]] | None = None,
     **kwargs: Any,
 ) -> onnx.ModelProto:
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = nodes_to_exclude or []
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)

959-972: Same mutable default issue in _quantize_awq_lite.

-def _quantize_awq_lite(
+def _quantize_awq_lite(
     onnx_model: onnx.ModelProto,
     data_reader: CalibrationDataReader,
     use_external_data_format: bool,
     calibration_eps: list[str],
     block_size: int,
     force_fp16: bool = False,
     enable_fast_path_using_high_sysram: bool = False,
     enable_weight_clipping: bool = False,
     use_zero_point: bool = False,
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
     input_shapes_profile: Sequence[dict[str, str]] | None = None,
     **kwargs: Any,
 ) -> onnx.ModelProto:
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = nodes_to_exclude or []
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)

604-604: Use logger.warning instead of deprecated logger.warn.

-        logger.warn("Augmented ONNX model or external data file was not found")
+        logger.warning("Augmented ONNX model or external data file was not found")
modelopt/onnx/quantization/qdq_utils.py (1)

311-325: Fail fast for unsupported bit-widths in get_tensor_dtype.

Avoid KeyError by validating num_bits and mapping lookups.

 def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
@@
-    if has_zero_point:
-        dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
-    else:
-        dtype_str = onnx_bit_dtype_signed_map[num_bits]
-    return onnx_dtype_map[dtype_str]
+    valid_bits = set(onnx_bit_dtype_unsigned_map) & set(onnx_bit_dtype_signed_map)
+    if num_bits not in valid_bits:
+        raise ValueError(f"Unsupported num_bits={num_bits}; expected one of {sorted(valid_bits)}")
+    dtype_str = (
+        onnx_bit_dtype_unsigned_map[num_bits]
+        if has_zero_point
+        else onnx_bit_dtype_signed_map[num_bits]
+    )
+    if dtype_str not in onnx_dtype_map:
+        raise ValueError(f"Unsupported dtype '{dtype_str}' in onnx_dtype_map")
+    return onnx_dtype_map[dtype_str]
modelopt/onnx/quantization/graph_utils.py (4)

707-714: Avoid mutable default (nodes_to_exclude) and normalize early.

-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
     onnx_model: onnx.ModelProto,
     precision_pattern_8bit: str | None = None,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
 ):
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)

812-816: Same mutable default in get_precision_info.

-def get_precision_info(
+def get_precision_info(
     onnx_model: onnx.ModelProto,
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
     **kwargs: Any,
 ):
@@
-    precision_info = None
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    precision_info = None

698-705: Slightly stricter validation using fullmatch.

Prevents accidental partial matches.

-    return bool(re.match(pattern, layers_str))
+    return bool(re.fullmatch(pattern, layers_str))
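
To make the distinction concrete (generic example with an unanchored pattern, not code from this PR):

import re

pattern = r"layers\.1"
print(bool(re.match(pattern, "layers.1")))       # True
print(bool(re.match(pattern, "layers.10")))      # True  - match() only anchors at the start
print(bool(re.fullmatch(pattern, "layers.10")))  # False - fullmatch() must consume the whole string

With explicit ^...$ anchors the two behave the same, which is why this is only a clarity nit.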

844-856: Consider re.fullmatch for pattern expansion (optional).

If callers pass anchored patterns it's fine; fullmatch would make intent explicit.

-        matched_node_names.extend([node.name for node in node_list if re.match(pattern, node.name)])
+        matched_node_names.extend([node.name for node in node_list if re.fullmatch(pattern, node.name)])
modelopt/onnx/quantization/quant_utils.py (1)

160-177: Guard update_block_size when w is required.

When block_size == -1 or num_bits == 8, w must be provided.

 def update_block_size(
     num_bits: int, block_size: int, quantize_axis: int = 0, w: np.ndarray = None
 ) -> int:
@@
-    if block_size is not None and (block_size == -1 or num_bits == 8):
-        return w.shape[quantize_axis]
+    if block_size is not None and (block_size == -1 or num_bits == 8):
+        assert w is not None, "Weight tensor 'w' is required to resolve per-channel block size"
+        return w.shape[quantize_axis]
     return block_size
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b6a39be and b069cec.

📒 Files selected for processing (6)
  • examples/windows/onnx_ptq/genai_llm/README.md (1 hunks)
  • examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
  • modelopt/onnx/quantization/graph_utils.py (2 hunks)
  • modelopt/onnx/quantization/int4.py (27 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (8 hunks)
  • modelopt/onnx/quantization/quant_utils.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/windows/onnx_ptq/genai_llm/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/onnx/quantization/quant_utils.py (1)
  • get_num_bits (179-193)
modelopt/onnx/quantization/int4.py (4)
modelopt/onnx/quantization/graph_utils.py (2)
  • _find_int4_quantizable_weights (629-666)
  • get_precision_info (812-841)
modelopt/onnx/quantization/ort_utils.py (1)
  • create_inference_session (213-245)
modelopt/onnx/quantization/quant_utils.py (8)
  • _pad (200-214)
  • dq_tensor (328-344)
  • find_scales (242-289)
  • get_num_bits (179-193)
  • quant_tensor (347-365)
  • reshape_scales_for_per_channel_nodes (231-239)
  • rtn (292-325)
  • update_block_size (160-176)
modelopt/onnx/quantization/qdq_utils.py (2)
  • insert_dq_nodes (349-419)
  • insert_qdq_nodes (422-476)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (7)
examples/windows/onnx_ptq/genai_llm/quantize.py (1)

428-444: LGTM: plumbing new flags into quantize_int4.

Flags are passed through consistently and printed for visibility.

modelopt/onnx/quantization/int4.py (1)

221-231: LGTM: precision_info propagation and per-channel handling.

Per-weight bits flow cleanly into scale computation, Q/DQ insertion, and attribute updates.

Also applies to: 272-279, 293-300, 572-579, 1274-1283, 1290-1297

modelopt/onnx/quantization/qdq_utils.py (2)

349-356: LGTM: per-weight dtype and attrs wired into DQ/QDQ insertion.

Correctly derives dtype from num_bits and updates attributes for per‑channel.

Also applies to: 422-427


58-63: Verify ONNX/ORT support for INT4/UINT4 — verification failed here (onnx not installed).

Confirm CI/runtime ONNX >= 1.16 and that the ORT build exposes TensorProto.INT4/UINT4.

Affects: modelopt/onnx/quantization/qdq_utils.py (lines 58–63, 311–325)

Run locally/CI and paste output:

python - <<'PY'
import sys
try:
    import onnx, onnxruntime as ort
except Exception as e:
    print("ERROR:", e, file=sys.stderr); sys.exit(2)
print("ONNX version:", onnx.__version__)
print("Has INT4:", hasattr(onnx.TensorProto, "INT4"))
print("Has UINT4:", hasattr(onnx.TensorProto, "UINT4"))
print("ORT version:", ort.__version__)
PY
modelopt/onnx/quantization/graph_utils.py (1)

669-696: LGTM: token‑exact INT8 layer matching avoids substring false positives.

Good robustness improvement over naive substring checks.
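
The failure mode this guards against, sketched with a hypothetical token-window check (the production logic in graph_utils.py differs in detail):

def naive_match(pattern: str, node_name: str) -> bool:
    return pattern in node_name  # "layers.1" also matches "layers.10", "layers.11", ...

def token_match(pattern: str, node_name: str) -> bool:
    # Compare dot/slash-separated tokens as a contiguous window instead of raw substrings.
    tokens = [t for t in node_name.replace("/", ".").split(".") if t]
    window = [t for t in pattern.replace("/", ".").split(".") if t]
    return any(tokens[i:i + len(window)] == window for i in range(len(tokens) - len(window) + 1))

print(naive_match("layers.1", "/model/layers.10/attn/qkv_proj/MatMul"))  # True (false positive)
print(token_match("layers.1", "/model/layers.10/attn/qkv_proj/MatMul"))  # False
print(token_match("layers.1", "/model/layers.1/attn/qkv_proj/MatMul"))   # True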

modelopt/onnx/quantization/quant_utils.py (2)

200-215: LGTM: fixed quantize_axis bounds check in _pad.

Prevents out-of-bounds indexing.


242-290: LGTM: scale computation clamps to CLIP_MIN and validates zero-points.

Improves numerical stability and correctness for affine quantization.
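
For context, a generic NumPy sketch of the affine scale/zero-point math being reviewed (this is not the repo's find_scales implementation; the helper name and exact clamping are assumptions):

import numpy as np

CLIP_MIN = 1e-5  # floor so a constant block never produces a zero scale

def affine_scale_zp(w_block: np.ndarray, num_bits: int = 4):
    qmin, qmax = 0, 2**num_bits - 1                        # unsigned range when a zero-point is used
    w_min, w_max = float(w_block.min()), float(w_block.max())
    scale = max((w_max - w_min) / (qmax - qmin), CLIP_MIN)
    zp = int(np.clip(round(-w_min / scale), qmin, qmax))   # zero-point must stay inside [qmin, qmax]
    return scale, zp

w = np.random.randn(128).astype(np.float32)
s, zp = affine_scale_zp(w)
q = np.clip(np.round(w / s) + zp, 0, 15)   # quantize
w_hat = (q - zp) * s                       # dequantize for a quick error check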

@kevalmorabia97 kevalmorabia97 requested review from vishalpandya1990 and removed request for kevalmorabia97 September 22, 2025 06:57
Contributor

@i-riyad i-riyad left a comment

LGTM

@vishalpandya1990 vishalpandya1990 left a comment

LGTM

@ynankani ynankani force-pushed the ynankani/mixed_precision-int4-int8 branch from b069cec to 6201e7b on September 23, 2025 06:33

copy-pr-bot bot commented Sep 23, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (5)
modelopt/onnx/quantization/int4.py (3)

149-156: Keep zero‑point dtype consistent (don’t cast zp to weight dtype).

Pass zp unchanged to rtn. Casting zp to w.dtype can lead to subtle rounding differences; zp is an integer offset by definition.

Apply:

-                    qw = rtn(
-                        np.asarray(w),
-                        s,
-                        block_size_updated,
-                        quantize_axis=gather_quantize_axis,
-                        zp=zp if zp is None else zp.astype(w.dtype),
-                        num_bits=num_bits,
-                    )
+                    qw = rtn(
+                        np.asarray(w),
+                        s,
+                        block_size_updated,
+                        quantize_axis=gather_quantize_axis,
+                        zp=zp,
+                        num_bits=num_bits,
+                    )

1327-1333: Avoid mutable default for nodes_to_exclude in public API.

Use None default to prevent cross‑call mutation.

Apply:

-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = nodes_to_exclude or []
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]

97-101: Remove unused CLIP_MIN constant here (dup in quant_utils).

It’s duplicated and unused in this module.

Apply:

-# following min-value for clip is taken from AutoAWQ where zero-point based quantization is
-# supported and working
-CLIP_MIN = 1e-5
modelopt/onnx/quantization/graph_utils.py (2)

698-705: Regex is fine; consider re.fullmatch for clarity (optional).

Anchors already enforce full‑string match; switching to re.fullmatch would be clearer but not required.


707-711: Avoid mutable default list for nodes_to_exclude.

Prevent shared state across calls.

Apply:

-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [r"/lm_head"])
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b069cec and 6201e7b.

📒 Files selected for processing (6)
  • examples/windows/onnx_ptq/genai_llm/README.md (1 hunks)
  • examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
  • modelopt/onnx/quantization/graph_utils.py (2 hunks)
  • modelopt/onnx/quantization/int4.py (27 hunks)
  • modelopt/onnx/quantization/qdq_utils.py (8 hunks)
  • modelopt/onnx/quantization/quant_utils.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/windows/onnx_ptq/genai_llm/README.md
  • examples/windows/onnx_ptq/genai_llm/quantize.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-22T21:14:46.661Z
Learnt from: gcunhase
PR: NVIDIA/TensorRT-Model-Optimizer#354
File: modelopt/onnx/quantization/qdq_utils.py:874-944
Timestamp: 2025-09-22T21:14:46.661Z
Learning: In ModelOpt's ONNX quantization pipeline, DequantizeLinear nodes always have 3 inputs (data, scale, zero_point) and zero-points are always provided as initializers, not as Constant nodes. This means the `remove_graph_input_q` function can safely assume `dq_node[0].input[2]` exists and is an initializer.

Applied to files:

  • modelopt/onnx/quantization/int4.py
  • modelopt/onnx/quantization/qdq_utils.py
🧬 Code graph analysis (3)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/graph_utils.py (1)
  • get_precision_info (812-841)
modelopt/onnx/quantization/quant_utils.py (8)
  • _pad (200-214)
  • dq_tensor (328-344)
  • find_scales (242-289)
  • get_num_bits (179-193)
  • quant_tensor (347-365)
  • reshape_scales_for_per_channel_nodes (231-239)
  • rtn (292-325)
  • update_block_size (160-176)
modelopt/onnx/quantization/qdq_utils.py (2)
  • insert_dq_nodes (349-419)
  • insert_qdq_nodes (422-476)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/onnx/quantization/quant_utils.py (1)
  • get_num_bits (179-193)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • num_bits (178-180)
  • num_bits (183-185)
  • axis (277-279)
  • axis (282-284)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (19)
modelopt/onnx/quantization/int4.py (6)

124-131: Correct: block size updated per weight bit-width.

Using get_num_bits + update_block_size for Gather is the right call for per-channel INT8.


176-177: Good: per‑channel scale reshape for Gather.

reshape_scales_for_per_channel_nodes correctly handles 8‑bit/per‑channel shaping.


224-231: Mixed‑precision plumbing LGTM.

Threading precision_info through scale/quant and using update_block_size is sound.


283-289: Fix previously reported Gather DQ scales misuse — verified.

You’re now passing gather_s_map (not GEMM scales) into insert_dq_nodes for Gather. Issue addressed.


572-579: AWQ‑clip DQ insertion honors per‑channel attributes.

Using reshape_scales_for_per_channel_nodes + qdq.insert_dq_nodes with precision_info is correct.


976-988: Mixed‑precision disabled for per‑subgraph AWQ‑lite (intent?).

The assertion forbids precision_info with awqlite_run_per_subgraph=True. Confirm that’s intentional and document via a log/warning.

modelopt/onnx/quantization/qdq_utils.py (5)

311-325: Validate num_bits to fail fast with a clear error.

Guard lookups to avoid KeyError for unsupported bit‑widths.

Apply:

 def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
@@
-    if has_zero_point:
-        dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
-    else:
-        dtype_str = onnx_bit_dtype_signed_map[num_bits]
-    return onnx_dtype_map[dtype_str]
+    valid_bits = onnx_bit_dtype_unsigned_map.keys() & onnx_bit_dtype_signed_map.keys()
+    if num_bits not in valid_bits:
+        raise ValueError(f"Unsupported num_bits={num_bits}; expected one of {sorted(valid_bits)}")
+    dtype_str = (
+        onnx_bit_dtype_unsigned_map[num_bits]
+        if has_zero_point
+        else onnx_bit_dtype_signed_map[num_bits]
+    )
+    return onnx_dtype_map[dtype_str]

327-338: Per‑channel attr update logic is appropriate.

Axis override and dropping block_size for 8‑bit/per‑channel looks good.


349-357: DQ insertion API extension LGTM.

precision_info threading and per‑weight dtype/attrs are correct.


422-469: QDQ insertion handles per‑weight bit‑widths correctly.

Using get_tensor_dtype with num_bits and creating zp/outputs accordingly is sound.


945-1014: Assumption aligns with repo practice: DQ has 3 inputs, zp initializer.

remove_graph_input_q relies on zp being an initializer; your insert_qdq/insert_dq_nodes satisfy this.

modelopt/onnx/quantization/graph_utils.py (3)

669-696: Exact token matching avoids substring false‑positives.

Sliding token window fix is correct and robust for layer matching.


812-841: Same mutable default issue in get_precision_info.

Mirror the fix here.

Apply:

-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
@@
-    if enable_mixed_quant:
-        precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude)
+    if enable_mixed_quant:
+        precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude or [r"/lm_head"])

735-801: Heuristic 8‑bit selection is reasonable.

Pattern‑based grouping and sampling (first/last eighth + every third) is acceptable as a default.

modelopt/onnx/quantization/quant_utils.py (5)

160-177: Block size updater LGTM.

Handles per‑channel for 8‑bit and -1 sentinel correctly.


200-215: Axis bounds check fixed.

Using 0 <= quantize_axis < len(w.shape) prevents out‑of‑range errors.


242-290: Scale/zero‑point computation handles edge cases.

CLIP_MIN clamp and zp range validation look good.


292-326: Quantization math is consistent with per‑channel handling.

Repeat logic and clipping per num_bits are correct.
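
A generic sketch of the repeat-and-clip pattern the comment refers to (symmetric block-wise RTN in plain NumPy; not the repo's rtn() implementation):

import numpy as np

def rtn_blockwise(w: np.ndarray, scales: np.ndarray, block_size: int, num_bits: int = 4):
    # scales holds one value per block along axis 0; repeat it back to per-element shape.
    s_full = np.repeat(scales, block_size, axis=0)[: w.shape[0]]
    maxq = 2 ** (num_bits - 1) - 1   # 7 for 4-bit, 127 for 8-bit
    minq = -(2 ** (num_bits - 1))
    return np.clip(np.round(w / s_full), minq, maxq)

w = np.random.randn(64, 32).astype(np.float32)
scales = np.abs(w).reshape(4, 16, 32).max(axis=1) / 7.0   # one scale per 16-row block
q = rtn_blockwise(w, scales, block_size=16)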


231-240: Per‑channel scale reshape helper LGTM.

Conditioning on num_bits==8 or block_size==-1 is appropriate.

@ynankani ynankani enabled auto-merge (squash) September 23, 2025 09:54
@ynankani ynankani disabled auto-merge September 23, 2025 10:00
@ynankani ynankani enabled auto-merge (squash) September 23, 2025 10:00
@ynankani ynankani disabled auto-merge September 23, 2025 10:00
@ynankani ynankani enabled auto-merge (squash) September 23, 2025 10:00
@ynankani ynankani disabled auto-merge September 23, 2025 10:26
@ynankani ynankani enabled auto-merge (squash) September 23, 2025 10:26
@ynankani ynankani force-pushed the ynankani/mixed_precision-int4-int8 branch from 6201e7b to b069cec on September 23, 2025 12:09
@ynankani ynankani merged commit 8b1cedf into main Sep 23, 2025
27 checks passed
@ynankani ynankani deleted the ynankani/mixed_precision-int4-int8 branch September 23, 2025 12:09
yeyu-nvidia pushed a commit that referenced this pull request Oct 1, 2025