[5506930]Add support in ModelOpt for generating mixed-precision (INT4… #310
Conversation
Walkthrough
Adds mixed-precision per-weight quantization (INT4/INT8), centralizes block-aware quant math into shared utilities, extends Q/DQ and QDQ insertion to honor per-weight bit widths, and adds new CLI flags.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant CLI as quantize.py (CLI)
participant INT4 as modelopt.onnx.quantization.int4.quantize
participant MAP as graph_utils.get_precision_info
participant PIPE as RTN/AWQ/AWQ-lite
participant QUTIL as quant_utils
participant QDQ as qdq_utils
CLI->>INT4: quantize(..., enable_mixed_quant, layers_8bit)
alt layers_8bit provided
INT4-->>INT4: force enable_mixed_quant = true
end
alt enable_mixed_quant == true
INT4->>MAP: get_precision_info(model, precision_pattern_8bit=layers_8bit)
MAP-->>INT4: precision_info (weight -> 4|8)
else
INT4-->>INT4: precision_info = None
end
INT4->>PIPE: run quant path(..., precision_info, enable_mixed_quant)
PIPE->>QUTIL: find_scales / quant_tensor / rtn(..., num_bits via precision_info)
QUTIL-->>PIPE: quantized weights, scales, zero-points
PIPE->>QDQ: insert_dq_nodes / insert_qdq_nodes(..., precision_info)
QDQ-->>PIPE: nodes with per-weight dtype (INT4/INT8) and zero-point handling
PIPE-->>INT4: updated ONNX model
INT4-->>CLI: write quantized model
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
modelopt/onnx/quantization/int4.py (6)
579-587: Fix: pass a GraphSurgeon graph to _quantize_gather_nodes in AWQ-clip.
_quantize_gather_nodes iterates graph.nodes (GraphSurgeon), but here graph is an ONNX GraphProto (augmented_model.graph). This will throw at runtime. Apply:
-    gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
-        graph,
+    gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
+        graph_gs,
         nodes_to_exclude,
         gather_quantize_axis,
         gather_block_size,
         use_zero_point=False,
         dq_only=True,
     )
1327-1334: Per-channel attribute missing for Gather DQ nodes in AWQ-lite.
Mirror the per-channel handling for Gather weights too.
 qdq.insert_dq_nodes(
     graph_gs,
     gather_s_map,
     quantized_weights=gather_w_map,
     attributes=gather_dq_node_attributes,
     zero_points=gather_zp_map if use_zero_point else None,
-    precision_info=precision_info,
+    precision_info=precision_info,
+    is_per_channel=is_per_channel,
 )
159-166: Avoid mutable default list for nodes_to_exclude in RTN API.
-def quantize_rtn(
+def quantize_rtn(
     onnx_model: onnx.ModelProto,
     block_size: int,
     dq_only: bool = False,
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
     precision_info: dict[str, int] | None = None,
     **kwargs: Any,
 ) -> onnx.ModelProto:
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])
457-468: Avoid mutable default list for nodes_to_exclude in AWQ-clip API.
-def _quantize_awq_clip(
+def _quantize_awq_clip(
     onnx_model: onnx.ModelProto,
@@
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])
1000-1014: Avoid mutable default list for nodes_to_exclude in AWQ-lite API.
-def _quantize_awq_lite(
+def _quantize_awq_lite(
@@
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])
1470-1484: Avoid mutable default list for nodes_to_exclude; add params doc for mixed precision.
-def quantize(
+def quantize(
@@
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
@@
-    logger.debug(f"Excluding nodes matching patterns: {nodes_to_exclude}")
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    logger.debug(f"Excluding nodes matching patterns: {nodes_to_exclude}")
Optional: Update the docstring to mention k_quant_mixed and int8_layers behavior.
🧹 Nitpick comments (8)
modelopt/onnx/quantization/quant_utils.py (3)
169-171: Consider inlining this simple calculation.
This helper function is only used once (in _pad) and performs a trivial calculation. Consider inlining it directly where used to reduce unnecessary abstraction.
-def _next_block_size_multiple(x: float, block_size: int) -> float:
-    return math.ceil(x / block_size) * block_size
-
 def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray:
     """Pads `w` to next largest multiple of block_size, on quantize_axis."""
     assert quantize_axis <= len(w.shape), (
         f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
     )
     if w.shape[quantize_axis] % block_size == 0:
         return w
-    pad_width = (
-        _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis]
-    )
+    pad_width = math.ceil(w.shape[quantize_axis] / block_size) * block_size - w.shape[quantize_axis]
190-201: Consider using numpy slicing syntax for all axes.
The function handles only 2D arrays but could be made more general and consistent using numpy's ellipsis syntax.
 def _depad(w: np.ndarray, orig_shape: tuple, quantize_axis: int = 0) -> np.ndarray:
     """Depad quantize_axis to original shape."""
     if w.shape == orig_shape:
         return w
-    ans = None
-    if quantize_axis == 0:
-        ans = w[0 : orig_shape[0], ...]
-    elif quantize_axis == 1:
-        ans = w[..., 0 : orig_shape[1]]
-    else:
-        raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array")
-    return ans
+
+    if quantize_axis not in [0, 1]:
+        raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array")
+
+    slices = [slice(None)] * len(w.shape)
+    slices[quantize_axis] = slice(0, orig_shape[quantize_axis])
+    return w[tuple(slices)]
323-338: Add docstring details about return values.
The function's docstring should specify what the three return values represent.
 def quant_tensor(
     w: np.ndarray,
     block_size: int,
     quantize_axis: int = 0,
     alpha: float = 1.0,
     use_zero_point: bool = False,
     precision_info: dict[str, int] | None = None,
     name: str | None = None,
 ):
-    """Quantize a tensor using alpha etc. and return the quantized tensor."""
+    """Quantize a tensor using alpha etc. and return the quantized tensor.
+
+    Returns:
+        tuple: A tuple containing:
+            - wq: The quantized weight tensor (np.ndarray)
+            - scale: The scale factors used for quantization (np.ndarray)
+            - zp: The zero-point values (np.ndarray or None if not using zero-point)
+    """
     scale, zp = find_scales(
         w, block_size, quantize_axis, alpha, use_zero_point, precision_info, name
     )
     wq = rtn(w, scale, block_size, quantize_axis, zp, precision_info, name)
     return wq, scale, zp
examples/windows/onnx_ptq/genai_llm/quantize.py (1)
599-613: Consider validating the int8_layers pattern format.
The int8_layers parameter accepts a comma-separated string of layer patterns, but there's no validation of the format. Consider adding basic validation to catch user errors early.
 parser.add_argument(
     "--int8_layers",
     type=str,
     default="",
     help=(
         "Comma-separated list of layer patterns to quantize to INT8 instead of INT4."
         "Example: 'layers.0,layers.1,lm_head'"
     ),
 )
+
+def validate_int8_layers(layers_str: str) -> bool:
+    """Validate the format of int8_layers string."""
+    if not layers_str:
+        return True
+    # Basic validation: check for valid characters and structure
+    import re
+    pattern = r'^[a-zA-Z0-9_.,\-]+$'
+    return bool(re.match(pattern, layers_str))
Then add validation after parsing:
 if args.int8_layers and not validate_int8_layers(args.int8_layers):
     parser.error("Invalid format for --int8_layers. Use comma-separated layer patterns.")
modelopt/onnx/quantization/qdq_utils.py (3)
310-322: Add input validation for the attributes parameter.
The function modifies the attributes dictionary in place, which could cause unexpected side effects if the same dictionary is reused.
 def update_attributes(attrib: dict[str, Any] | None = None, is_per_channel: bool = False):
     """Update attribute dictionary for quantization nodes.
     If per-channel quantization is enabled, sets the 'axis' attribute to 1 and
     removes the 'block_size' attribute if present.
+
+    Args:
+        attrib: Attribute dictionary to update (will be modified in place if not None)
+        is_per_channel: Whether to enable per-channel quantization
+
+    Returns:
+        The updated attribute dictionary
     """
     if is_per_channel:
         if attrib is not None:
             attrib["axis"] = 1
             if "block_size" in attrib:
                 attrib.pop("block_size")
     return attrib
354-369: Consider extracting tensor dtype selection logic to a helper function.
The tensor dtype selection logic is duplicated in both insert_dq_nodes and insert_qdq_nodes. Consider extracting it to reduce duplication and improve maintainability.
+def get_tensor_dtype(precision_info: dict[str, int] | None, name: str, has_zero_point: bool) -> int:
+    """Get the appropriate tensor dtype based on precision info and zero point presence.
+
+    Args:
+        precision_info: Dictionary mapping tensor names to bit widths
+        name: Name of the tensor
+        has_zero_point: Whether the tensor has a zero point
+
+    Returns:
+        ONNX tensor data type constant
+    """
+    if precision_info and name in precision_info:
+        bit_width = precision_info[name]
+        if has_zero_point:
+            dtype_str = onnx_bit_dtype_unsigned_map[bit_width]
+        else:
+            dtype_str = onnx_bit_dtype_signed_map[bit_width]
+        return onnx_dtype_map[dtype_str]
+    else:
+        # Default to 4-bit
+        return onnx.TensorProto.INT4 if not has_zero_point else onnx.TensorProto.UINT4

 def _insert_helper(
     name: str,
     wq: np.ndarray,
     scale: np.ndarray,
     dq_nodes: dict[str, gs.Node],
     zp: np.ndarray,
     attrs: dict[str, Any] | None = None,
     precision_info: dict[str, int] | None = None,
     is_per_channel: bool = False,
 ):
     attrib = dict(attrs) if attrs is not None else None
-    if precision_info and name in precision_info:
-        tensor_dtype = (
-            onnx_dtype_map[onnx_bit_dtype_signed_map[precision_info[name]]]
-            if zp is None
-            else onnx_dtype_map[onnx_bit_dtype_unsigned_map[precision_info[name]]]
-        )
-        # do per-channel quantization for int8 as no support for int8 block-wise dq node
-        if precision_info[name] == 8:
-            # reshape scale to be per-channel
-            scale = scale.reshape(-1)
-            attrib = update_attributes(attrib, True)
-    else:
-        tensor_dtype = onnx.TensorProto.INT4 if zp is None else onnx.TensorProto.UINT4
+    tensor_dtype = get_tensor_dtype(precision_info, name, zp is not None)
+
+    # do per-channel quantization for int8 as no support for int8 block-wise dq node
+    if precision_info and name in precision_info and precision_info[name] == 8:
+        # reshape scale to be per-channel
+        scale = scale.reshape(-1)
+        attrib = update_attributes(attrib, True)
361-364: Validate per-channel scale size before flattening.
Although scale.reshape(-1) preserves all elements, add an explicit check (e.g. assert scale.size == wq.shape[1]) before calling reshape to guard against mismatches when quantizing along axis 1.
modelopt/onnx/quantization/int4.py (1)
620-624: Use logger.warning instead of deprecated logger.warn.
-    logger.warn("Augmented ONNX model or external data file was not found")
+    logger.warning("Augmented ONNX model or external data file was not found")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
modelopt/onnx/quantization/int4.py (29 hunks)
modelopt/onnx/quantization/qdq_utils.py (8 hunks)
modelopt/onnx/quantization/quant_utils.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/quant_utils.py (5)
_pad (173-187), dq_tensor (297-320), find_scales (204-252), quant_tensor (323-337), rtn (255-294)
modelopt/onnx/quantization/qdq_utils.py (2)
insert_dq_nodes (324-406), insert_qdq_nodes (409-463)
modelopt/onnx/quantization/graph_utils.py (1)
expand_node_names_from_patterns (628-639)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
num_bits (180-182), num_bits (185-187), axis (279-281), axis (284-286)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (10)
modelopt/onnx/quantization/quant_utils.py (3)
160-167: LGTM! Clean implementation of bit-width selection logic.
The function provides a clear and simple way to get the number of bits for quantization based on the precision info dictionary, with a sensible default of 4 bits.
173-188: LGTM! Proper padding implementation with clear assertions.
The padding logic correctly handles different axes and includes appropriate assertions for validation.
297-321: LGTM! Clear dequantization implementation.
The dequantization logic correctly handles both signed and unsigned cases with proper padding and depadding.
examples/windows/onnx_ptq/genai_llm/quantize.py (2)
368-369: LGTM! Status output correctly includes the new k_quant_mixed flag.
The logging statement properly shows the new mixed-precision quantization option.
438-439: LGTM! Proper propagation of mixed-precision parameters.
The new parameters are correctly passed to the quantize_int4 function.
modelopt/onnx/quantization/qdq_utils.py (2)
57-61: LGTM! Proper extension of dtype mappings for 4-bit support.
The addition of INT4 and UINT4 to the dtype maps enables proper 4-bit quantization support.
432-440: LGTM! Consistent handling of per-weight precision in QDQ nodes.
The implementation correctly mirrors the DQ node logic for handling per-weight precision information.
modelopt/onnx/quantization/int4.py (3)
206-209: LGTM: precision_info is correctly threaded in RTN path.
Scale computation, quantization, and Q/DQ insertion honor per-weight precision and per-channel when applicable.
Also applies to: 240-243, 248-255, 259-265
315-322: LGTM: precision_info propagation in AWQ-clip compute path.
The clip search and final quantization use per-weight precision consistently.
Also applies to: 362-366, 526-528, 552-555
705-707: LGTM: precision_info propagation in AWQ-lite compute path.
Both per-node and per-subgraph scale searches respect mixed precision; quantization/dequantization calls pass the context correctly.
Also applies to: 762-771, 782-783, 882-884, 939-952, 1112-1131, 1182-1188
dq_node_attributes = {"axis": 0, "block_size": block_size}

qdq.insert_dq_nodes(
    graph_gs,
    scales,
    quantized_weights=gemm_weights_quantized,
    attributes=dq_node_attributes,
    zero_points=zero_points if use_zero_point else None,
    precision_info=precision_info,
)
🛠️ Refactor suggestion
Per-channel attribute not propagated in AWQ-lite.
When block_size == -1, you set per-channel in RTN and AWQ-clip, but not here. TRT plugins require this attribute for correct dequantization.
Apply:
dq_node_attributes = {"axis": 0, "block_size": block_size}
+ is_per_channel = block_size == -1
qdq.insert_dq_nodes(
graph_gs,
scales,
quantized_weights=gemm_weights_quantized,
attributes=dq_node_attributes,
zero_points=zero_points if use_zero_point else None,
- precision_info=precision_info,
+ precision_info=precision_info,
+ is_per_channel=is_per_channel,
)
modelopt/onnx/quantization/int4.py
Outdated
def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
    """Check if layer should be quantized to INT8.
    The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
    The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'.
    To match these, we:
    - Remove the leading slash from the node name.
    - Replace all '/' with '.' to match the naming convention of the initializer.
    This allows us to correctly identify which weights should be quantized to INT8.
    """
    if not int8_layers:
        return False
    normalized_patterns = []
    for pattern in int8_layers:
        p = pattern.lstrip("/")
        p = p.replace("/", ".")
        normalized_patterns.append(p)
    return any(norm_pattern in layer_name for norm_pattern in normalized_patterns)
🛠️ Refactor suggestion
Avoid substring false-positives when matching INT8 layers.
Substring matching with the in operator can misclassify, e.g., layers.1 matches layers.11. Use token-exact matching.
-def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
+def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool:
@@
- normalized_patterns = []
- for pattern in int8_layers:
- p = pattern.lstrip("/")
- p = p.replace("/", ".")
- normalized_patterns.append(p)
- return any(norm_pattern in layer_name for norm_pattern in normalized_patterns)
+ # Normalize both to dot-delimited tokens and require exact token sequence match.
+ def tokens(s: str) -> list[str]:
+ return s.lstrip("/").replace("/", ".").split(".")
+ hay = tokens(layer_name)
+ for pat in int8_layers:
+ needle = tokens(pat)
+ n, m = len(hay), len(needle)
+ for i in range(0, n - m + 1):
+ if hay[i : i + m] == needle:
+ return True
+ return False
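For illustration, a self-contained sketch of the token-exact idea (hypothetical helper, not the committed code) showing why it avoids the layers.1 vs layers.11 false positive:
# Minimal sketch of token-exact matching vs. substring matching.
# The helper name and inputs here are illustrative, not the PR's actual code.
def matches_token_exact(layer_name: str, pattern: str) -> bool:
    hay = layer_name.lstrip("/").replace("/", ".").split(".")
    needle = pattern.lstrip("/").replace("/", ".").split(".")
    n, m = len(hay), len(needle)
    return any(hay[i : i + m] == needle for i in range(n - m + 1))

name = "model.layers.11.attn.qkv_proj.MatMul.weight"
print("layers.1" in name)                      # True  (substring false positive)
print(matches_token_exact(name, "layers.1"))   # False (token-exact match)
print(matches_token_exact(name, "layers.11"))  # True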
modelopt/onnx/quantization/int4.py
Outdated
def get_layer_precision_mapping(
    onnx_model: onnx.ModelProto,
    int8_precision_pattern: str | None = None,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
    graph = onnx_model.graph

    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
    # Collect quantizable weight tensors
🛠️ Refactor suggestion
Don’t use a mutable list as a default argument.
nodes_to_exclude: list[str] | None = [r"/lm_head"] risks cross-call mutation.
-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
onnx_model: onnx.ModelProto,
int8_precision_pattern: str | None = None,
- nodes_to_exclude: list[str] | None = [r"/lm_head"],
+ nodes_to_exclude: list[str] | None = None,
):
- graph = onnx_model.graph
+ graph = onnx_model.graph
+ nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
Optional: allow regex patterns in int8_precision_pattern to be expanded via expand_node_names_from_patterns for robustness.
def rtn(
    w: np.ndarray,
    s: np.ndarray,
    block_size: int,
    quantize_axis: int = 0,
    zp: np.ndarray = None,
    precision_info: dict[str, int] | None = None,
    name: str | None = None,
) -> np.ndarray:
    """Quantizes `w` with scale factors `s` via Round-to-Nearest.
    Ties are broken by rounding to the nearest even number.
    """
    num_bits = get_num_bits(precision_info, name)
    # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node,
    # set block_size to the size of the quantize_axis dimension to do per-channel quantization
    if block_size == -1 or num_bits == 8:
        block_size = w.shape[quantize_axis]
    w_padded = _pad(w, block_size, quantize_axis)
    num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis]
    if zp is None:
        maxq = 2 ** (num_bits - 1) - 1
        minq = -(2 ** (num_bits - 1))
        w_padded = (
            np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
            .clip(minq, maxq)
            .astype(np.int8)
        )
    else:
        maxq = (2**num_bits) - 1
        minq = 0
        w_padded = (
            (
                np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
                + zp.repeat(num_blocks, axis=quantize_axis)
            )
            .clip(minq, maxq)
            .astype(np.int8)
        )
    return _depad(w_padded, w.shape, quantize_axis)
🛠️ Refactor suggestion
Improve numerical stability for edge cases.
When using zero-point quantization, division by s could cause issues if s contains values close to zero (even though clipped by CLIP_MIN). Consider adding a safety check.
def rtn(
w: np.ndarray,
s: np.ndarray,
block_size: int,
quantize_axis: int = 0,
zp: np.ndarray = None,
precision_info: dict[str, int] | None = None,
name: str | None = None,
) -> np.ndarray:
"""Quantizes `w` with scale factors `s` via Round-to-Nearest.
Ties are broken by rounding to the nearest even number.
"""
num_bits = get_num_bits(precision_info, name)
# If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node,
# set block_size to the size of the quantize_axis dimension to do per-channel quantization
if block_size == -1 or num_bits == 8:
block_size = w.shape[quantize_axis]
w_padded = _pad(w, block_size, quantize_axis)
num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis]
+
+ # Ensure scale values are not too small to avoid numerical instability
+ if np.any(np.abs(s) < CLIP_MIN):
+ raise ValueError(f"Scale values too small for stable quantization: min={np.min(np.abs(s))}")
+
if zp is None:
maxq = 2 ** (num_bits - 1) - 1
minq = -(2 ** (num_bits - 1))
w_padded = (
np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
.clip(minq, maxq)
.astype(np.int8)
)
else:
maxq = (2**num_bits) - 1
minq = 0
w_padded = (
(
np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
+ zp.repeat(num_blocks, axis=quantize_axis)
)
.clip(minq, maxq)
.astype(np.int8)
)
return _depad(w_padded, w.shape, quantize_axis)
)

parser.add_argument(
    "--k_quant_mixed",
What does k stand for in k_quant_mixed?
k stands for group-wise/block-wise quantization. We are doing INT4 block-wise quantization plus a few layers in INT8 per-channel quantization (currently we don't have INT8 block-wise DequantizeLinear node support on the trt-rtx side).
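To make the two granularities concrete, here is a small illustrative numpy sketch (shapes only, not the ModelOpt implementation): block-wise INT4 keeps one scale per block of rows per output column, while per-channel INT8 keeps a single scale per column.
import numpy as np

# Illustrative only: scale shapes for block-wise vs per-channel quantization
# of a (K, N) weight along axis 0. Not the ModelOpt implementation.
K, N, block_size = 512, 256, 128
w = np.random.randn(K, N).astype(np.float32)

# INT4 block-wise: one scale per (block of K) x N
scales_blockwise = np.abs(w).reshape(K // block_size, block_size, N).max(axis=1) / 7.0
print(scales_blockwise.shape)  # (4, 256)

# INT8 per-channel: the block spans the whole K axis, so one scale per column
scales_per_channel = np.abs(w).max(axis=0, keepdims=True) / 127.0
print(scales_per_channel.shape)  # (1, 256)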
Need to make a few more changes as per the discussion with Vishal. Will request review once the updated changes are ready.
Force-pushed from caa9965 to b6a39be
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)
417-450: Zero-point dtype should be unsigned.
When creating Q/DQ, zp should use the unsigned dtype. Pass has_zero_point=True to get_tensor_dtype.
-    tensor_dtype = get_tensor_dtype(num_bits)
+    tensor_dtype = get_tensor_dtype(num_bits)
@@
-    zp_tensor = make_gs_zp(name, scale.shape, tensor_dtype)
+    zp_tensor = make_gs_zp(name, scale.shape, get_tensor_dtype(num_bits, has_zero_point=True))
♻️ Duplicate comments (1)
modelopt/onnx/quantization/quant_utils.py (1)
292-326: Numerical safety in rtn.
Add a guard to avoid tiny s after broadcast.
-    if zp is None:
+    # Safety: avoid division by tiny scales
+    if np.any(np.abs(s) < CLIP_MIN):
+        s = np.where(np.abs(s) < CLIP_MIN, CLIP_MIN, s)
+    if zp is None:
         maxq = 2 ** (num_bits - 1) - 1
🧹 Nitpick comments (7)
modelopt/onnx/quantization/graph_utils.py (2)
629-667: Type hint misuse and minor robustness.
cast("int", weight_tensor.data_type) uses a string literal as a type param; use cast(int, ...) or drop the cast. Also consider guarding for Gemm transA (currently TODO) or explicitly documenting unsupported cases.
-    gemm_io_type = cast("int", weight_tensor.data_type)
+    gemm_io_type = cast(int, weight_tensor.data_type)
669-696: Token-exact matcher looks good. Add fast exit and docstring type.
Add -> bool and an early return for long patterns.
-def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
+def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool:
@@
-    for pat in int8_layers:
+    for pat in int8_layers:
         needle = tokens(pat)
         n, m = len(hay), len(needle)
+        if m == 0 or m > n:
+            continue
modelopt/onnx/quantization/quant_utils.py (2)
160-177: update_block_size: clarify contract for w=None.
The function dereferences w unconditionally when block_size == -1 or num_bits == 8. Either assert or handle None.
-    if block_size is not None and (block_size == -1 or num_bits == 8):
-        return w.shape[quantize_axis]
+    if block_size is not None and (block_size == -1 or num_bits == 8):
+        assert w is not None, "update_block_size requires `w` when per-channel/INT8 is requested"
+        return w.shape[quantize_axis]
196-198: Type hint for _next_block_size_multiple.
Return type is int in practice; adjust for readability/tools.
-def _next_block_size_multiple(x: float, block_size: int) -> float:
+def _next_block_size_multiple(x: float, block_size: int) -> int:
50-63: ONNX 4-bit enums: confirm availability.
onnx.TensorProto.INT4/UINT4 are relatively new; some environments may lack them. Add a fallback or guard. Option: probe availability once and raise a clear error.
-onnx_dtype_map = {
+onnx_dtype_map = {
@@
-    "INT4": onnx.TensorProto.INT4,
-    "UINT4": onnx.TensorProto.UINT4,
+    "INT4": getattr(onnx.TensorProto, "INT4", None),
+    "UINT4": getattr(onnx.TensorProto, "UINT4", None),
 }
+
+# Validate availability early
+if onnx_dtype_map["INT4"] is None or onnx_dtype_map["UINT4"] is None:
+    logger.error("This ONNX build does not support 4-bit dtypes (INT4/UINT4).")
+    # Optionally: raise, or fall back to INT8 path depending on product decision
Would you like me to wire a clean fallback to INT8 when 4-bit enums are missing?
modelopt/onnx/quantization/int4.py (2)
221-227: Pattern-vs-name expansion mismatch risk.
nodes_to_exclude is already expanded to concrete names before calling get_precision_info, which expects patterns (it internally expands again). This is benign but redundant and can mis-handle special regex chars in names. Two options: pass the original patterns into get_precision_info, or change get_precision_info to accept already-expanded names and skip expansion.
1369-1376: Doc/type mismatch for int8_layers.
Docstring says "comma-separated list" but default shows [] and the code expects a string or None. Clarify.
-    - **int8_layers** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
-        Default: [].
+    - **int8_layers** (str | None): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
+        Default: None.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
examples/windows/onnx_ptq/genai_llm/README.md (1 hunks)
examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
modelopt/onnx/quantization/graph_utils.py (2 hunks)
modelopt/onnx/quantization/int4.py (27 hunks)
modelopt/onnx/quantization/qdq_utils.py (8 hunks)
modelopt/onnx/quantization/quant_utils.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
- examples/windows/onnx_ptq/genai_llm/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/windows/onnx_ptq/genai_llm/quantize.py
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/onnx/quantization/quant_utils.py (1)
get_num_bits (179-193)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/graph_utils.py (2)
_find_quantizable_weights (629-666), get_precision_info (814-843)
modelopt/onnx/quantization/quant_utils.py (8)
_pad (200-214), dq_tensor (328-344), find_scales (242-289), get_num_bits (179-193), quant_tensor (347-366), rtn (292-325), update_block_size (160-176), update_scale_map_for_per_channel_nodes (231-239)
modelopt/onnx/quantization/qdq_utils.py (2)
insert_dq_nodes (324-406), insert_qdq_nodes (409-463)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
num_bits (180-182), num_bits (185-187), axis (279-281), axis (284-286)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/onnx/quantization/int4.py (2)
562-571: Per-channel handling is correctly propagated.
Good use of update_scale_map_for_per_channel_nodes and attribute adjustment in qdq insertion to set axis=1 and drop block_size for INT8/per-channel. This aligns with TRT plugin expectations.
Also applies to: 1264-1273
244-253: Gather path: mixed-precision propagation looks correct.
You compute per-weight num_bits and reuse update_block_size for gather. Nice reuse and consistency across RTN/AWQ paths.
Also applies to: 1253-1261
def get_layer_precision_mapping(
    onnx_model: onnx.ModelProto,
    int8_precision_pattern: str | None = None,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
    """Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model.
🛠️ Refactor suggestion
Avoid mutable default args and normalize exclusions.
Using a list default (nodes_to_exclude=[r"/lm_head"]) risks cross-call mutation. Also normalize nodes_to_exclude up-front.
Apply:
-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
onnx_model: onnx.ModelProto,
int8_precision_pattern: str | None = None,
- nodes_to_exclude: list[str] | None = [r"/lm_head"],
+ nodes_to_exclude: list[str] | None = None,
):
@@
- nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+ nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+ nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
Also applies to: 728-731
def get_precision_info(
    onnx_model: onnx.ModelProto,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
    **kwargs: Any,
):
    """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).
    This function determines the quantization precision for each weight tensor in the ONNX model,
    based on the provided configuration. If mixed quantization is enabled, it uses the layer
    precision mapping; otherwise, it returns None.
    Args:
        onnx_model (onnx.ModelProto): The ONNX model to analyze.
        nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
        **kwargs: Additional keyword arguments, such as:
            - enable_mixed_quant (bool): Whether to enable mixed quantization.
            - int8_layers (str): Comma-separated list of layer patterns to quantize to INT8.
    Returns:
        dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
        or None if mixed quantization is not enabled.
    """
    precision_info = None
    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
    int8_layers = kwargs.get("int8_layers")
    if enable_mixed_quant:
        precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude)
    else:
        precision_info = None
    return precision_info
🛠️ Refactor suggestion
Same mutable default issue in public API.
Apply same fix as above.
-def get_precision_info(
+def get_precision_info(
onnx_model: onnx.ModelProto,
- nodes_to_exclude: list[str] | None = [r"/lm_head"],
+ nodes_to_exclude: list[str] | None = None,
**kwargs: Any,
):
@@
- enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
+ nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+ enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
Committable suggestion skipped: line range outside the PR's diff.
def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
    """Get the appropriate tensor dtype based on precision info and zero point presence.
    Args:
        num_bits: Number of bits for quantization
        has_zero_point: Whether the tensor has a zero point
    Returns:
        ONNX tensor data type constant
    """
    if has_zero_point:
        dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
    else:
        dtype_str = onnx_bit_dtype_signed_map[num_bits]
    return onnx_dtype_map[dtype_str]
🛠️ Refactor suggestion
Validate num_bits and map lookups.
Guard against unexpected num_bits and derive dtype accordingly.
def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
@@
- if has_zero_point:
- dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
- else:
- dtype_str = onnx_bit_dtype_signed_map[num_bits]
+ valid_bits = onnx_bit_dtype_unsigned_map.keys() & onnx_bit_dtype_signed_map.keys()
+ if num_bits not in valid_bits:
+ raise ValueError(f"Unsupported num_bits={num_bits}; expected one of {sorted(valid_bits)}")
+ dtype_str = (
+ onnx_bit_dtype_unsigned_map[num_bits]
+ if has_zero_point
+ else onnx_bit_dtype_signed_map[num_bits]
+ )
return onnx_dtype_map[dtype_str]
Codecov Report
❌ Patch coverage
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #310      +/-   ##
==========================================
- Coverage   73.72%   73.48%   -0.24%
==========================================
  Files         172      172
  Lines       17484    17636     +152
==========================================
+ Hits        12890    12960      +70
- Misses       4594     4676      +82
"--enable_mixed_quant", | ||
default=False, | ||
action="store_true", | ||
help="True when we want to use mixed quantization", | ||
) | ||
parser.add_argument( | ||
"--int8_layers", |
Can we just assume mixed_quant is enabled if --int8_layers is non-empty, and remove the --enable_mixed_quant option?
If --int8_layers is specified, we select only the layers matching the patterns in int8_layers for INT8 quantization; if only --enable_mixed_quant is specified, we hardcode a selection of a few important layers, similar to what other quantization tools like model builder/llama.cpp do.
Example:
"python quantize.py ... --int8_layers="layers.0" --enable_mixed_quant" => all layers.0 will be quantized to INT8
"python quantize.py ... --enable_mixed_quant" => quantize to INT8 the first 1/8, last 1/8, and every 3rd layer of the attn and ffn layers below:
/model/layers.{i}/attn/qkv_proj/MatMul
/model/layers.{i}/attn/v_proj/MatMul
/model/layers.{i}/mlp/down_proj/MatMul
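An illustrative sketch of that default selection (hypothetical helper name; the real rule lives in get_layer_precision_mapping):
# Sketch of the default selection described above: first 1/8, last 1/8,
# and every 3rd of the remaining decoder layers go to INT8.
# This mirrors the description in this thread, not the exact ModelOpt code.
def default_int8_layer_indices(num_layers: int) -> list[int]:
    eighth = max(1, num_layers // 8)
    selected = set(range(eighth))                              # first 1/8
    selected.update(range(num_layers - eighth, num_layers))    # last 1/8
    selected.update(range(eighth, num_layers - eighth, 3))     # every 3rd in between
    return sorted(selected)

indices = default_int8_layer_indices(32)
patterns = [f"/model/layers.{i}/attn/qkv_proj/MatMul" for i in indices]
print(indices)   # [0, 1, 2, 3, 4, 7, 10, ..., 25, 28, 29, 30, 31]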
Gotcha, makes sense now.
default="", | ||
help=( | ||
"Comma-separated list of layer patterns to quantize to INT8 instead of INT4." | ||
"Example: 'layers.0,layers.1,lm_head'" |
nit: Overrides default mixed quant strategy. Example: 'layers.0,lm_head'
"--enable_mixed_quant", | ||
default=False, | ||
action="store_true", | ||
help="True when we want to use mixed quantization", |
nit: help="Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd layer quantized to INT8; others to INT4.",
…+INT8) ONNX models Signed-off-by: unknown <[email protected]>
…+INT8) ONNX models, refactored changes and handle comments Signed-off-by: unknown <[email protected]>
…+INT8) ONNX models, handle comments and rename functions,variables Signed-off-by: unknown <[email protected]>
Force-pushed from b6a39be to b069cec
Actionable comments posted: 0
🧹 Nitpick comments (12)
examples/windows/onnx_ptq/genai_llm/quantize.py (2)
364-366
: Harden the gate that flips mixed-quant on when --layers_8bit is provided.Trim whitespace and only flip when non-empty; also normalize the string once.
- if args.layers_8bit: - args.enable_mixed_quant = True + if args.layers_8bit and args.layers_8bit.strip(): + args.layers_8bit = args.layers_8bit.strip() + args.enable_mixed_quant = True
603-617
: Tighten CLI help and document implied behavior.Clarify that providing --layers_8bit implies mixed quant, and fix example punctuation.
- parser.add_argument( - "--enable_mixed_quant", + parser.add_argument( + "--enable_mixed_quant", default=False, action="store_true", help=( - "Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd attn, " - "mlp layers quantized to 8 bits; others to 4 bits." + "Enable mixed quantization. Default strategy: first 1/8, last 1/8, and every 3rd " + "attn/mlp layer quantized to INT8; others to INT4." ), ) - parser.add_argument( + parser.add_argument( "--layers_8bit", type=str, default="", - help=("Overrides default mixed quant strategy. Example: 'layers.0,lm_head'"), + help=("Overrides default strategy. Comma-separated patterns; providing this implies " + "--enable_mixed_quant. Example: 'layers.0,lm_head'."), )modelopt/onnx/quantization/int4.py (4)
180-186
: Avoid mutable default argument (nodes_to_exclude).Using [] as a default risks cross-call mutation.
-def quantize_rtn( +def quantize_rtn( onnx_model: onnx.ModelProto, block_size: int, dq_only: bool = False, - nodes_to_exclude: list[str] = [], + nodes_to_exclude: list[str] | None = None, **kwargs: Any, ) -> onnx.ModelProto: @@ - nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude) + nodes_to_exclude = nodes_to_exclude or [] + nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
439-448
: Same mutable default issue in _quantize_awq_clip.-def _quantize_awq_clip( +def _quantize_awq_clip( onnx_model: onnx.ModelProto, data_reader: CalibrationDataReader, use_external_data_format: bool, calibration_eps: list[str], block_size: int, force_fp16: bool = False, - nodes_to_exclude: list[str] = [], + nodes_to_exclude: list[str] | None = None, input_shapes_profile: Sequence[dict[str, str]] | None = None, **kwargs: Any, ) -> onnx.ModelProto: @@ - nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude) + nodes_to_exclude = nodes_to_exclude or [] + nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
959-972
: Same mutable default issue in _quantize_awq_lite.-def _quantize_awq_lite( +def _quantize_awq_lite( onnx_model: onnx.ModelProto, data_reader: CalibrationDataReader, use_external_data_format: bool, calibration_eps: list[str], block_size: int, force_fp16: bool = False, enable_fast_path_using_high_sysram: bool = False, enable_weight_clipping: bool = False, use_zero_point: bool = False, - nodes_to_exclude: list[str] = [], + nodes_to_exclude: list[str] | None = None, input_shapes_profile: Sequence[dict[str, str]] | None = None, **kwargs: Any, ) -> onnx.ModelProto: @@ - nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude) + nodes_to_exclude = nodes_to_exclude or [] + nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
604-604
: Use logger.warning instead of deprecated logger.warn.- logger.warn("Augmented ONNX model or external data file was not found") + logger.warning("Augmented ONNX model or external data file was not found")modelopt/onnx/quantization/qdq_utils.py (1)
311-325
: Fail fast for unsupported bit-widths in get_tensor_dtype.Avoid KeyError by validating num_bits and mapping lookups.
def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int: @@ - if has_zero_point: - dtype_str = onnx_bit_dtype_unsigned_map[num_bits] - else: - dtype_str = onnx_bit_dtype_signed_map[num_bits] - return onnx_dtype_map[dtype_str] + valid_bits = set(onnx_bit_dtype_unsigned_map) & set(onnx_bit_dtype_signed_map) + if num_bits not in valid_bits: + raise ValueError(f"Unsupported num_bits={num_bits}; expected one of {sorted(valid_bits)}") + dtype_str = ( + onnx_bit_dtype_unsigned_map[num_bits] + if has_zero_point + else onnx_bit_dtype_signed_map[num_bits] + ) + if dtype_str not in onnx_dtype_map: + raise ValueError(f"Unsupported dtype '{dtype_str}' in onnx_dtype_map") + return onnx_dtype_map[dtype_str]modelopt/onnx/quantization/graph_utils.py (4)
707-714
: Avoid mutable default (nodes_to_exclude) and normalize early.-def get_layer_precision_mapping( +def get_layer_precision_mapping( onnx_model: onnx.ModelProto, precision_pattern_8bit: str | None = None, - nodes_to_exclude: list[str] | None = [r"/lm_head"], + nodes_to_exclude: list[str] | None = None, ): @@ - nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude) + nodes_to_exclude = nodes_to_exclude or [r"/lm_head"] + nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
812-816
: Same mutable default in get_precision_info.-def get_precision_info( +def get_precision_info( onnx_model: onnx.ModelProto, - nodes_to_exclude: list[str] | None = [r"/lm_head"], + nodes_to_exclude: list[str] | None = None, **kwargs: Any, ): @@ - precision_info = None + nodes_to_exclude = nodes_to_exclude or [r"/lm_head"] + precision_info = None
698-705
: Slightly stricter validation using fullmatch.Prevents accidental partial matches.
- return bool(re.match(pattern, layers_str)) + return bool(re.fullmatch(pattern, layers_str))
844-856
: Consider re.fullmatch for pattern expansion (optional).If callers pass anchored patterns it's fine; fullmatch would make intent explicit.
- matched_node_names.extend([node.name for node in node_list if re.match(pattern, node.name)]) + matched_node_names.extend([node.name for node in node_list if re.fullmatch(pattern, node.name)])modelopt/onnx/quantization/quant_utils.py (1)
160-177
: Guard update_block_size when w is required.When block_size == -1 or num_bits == 8, w must be provided.
def update_block_size( num_bits: int, block_size: int, quantize_axis: int = 0, w: np.ndarray = None ) -> int: @@ - if block_size is not None and (block_size == -1 or num_bits == 8): - return w.shape[quantize_axis] + if block_size is not None and (block_size == -1 or num_bits == 8): + assert w is not None, "Weight tensor 'w' is required to resolve per-channel block size" + return w.shape[quantize_axis] return block_size
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/windows/onnx_ptq/genai_llm/README.md (1 hunks)
- examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
- modelopt/onnx/quantization/graph_utils.py (2 hunks)
- modelopt/onnx/quantization/int4.py (27 hunks)
- modelopt/onnx/quantization/qdq_utils.py (8 hunks)
- modelopt/onnx/quantization/quant_utils.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/windows/onnx_ptq/genai_llm/README.md
🧰 Additional context used
🧬 Code graph analysis (2)
- modelopt/onnx/quantization/qdq_utils.py (1)
  - modelopt/onnx/quantization/quant_utils.py (1)
    - get_num_bits (179-193)
- modelopt/onnx/quantization/int4.py (4)
  - modelopt/onnx/quantization/graph_utils.py (2)
    - _find_int4_quantizable_weights (629-666)
    - get_precision_info (812-841)
  - modelopt/onnx/quantization/ort_utils.py (1)
    - create_inference_session (213-245)
  - modelopt/onnx/quantization/quant_utils.py (8)
    - _pad (200-214)
    - dq_tensor (328-344)
    - find_scales (242-289)
    - get_num_bits (179-193)
    - quant_tensor (347-365)
    - reshape_scales_for_per_channel_nodes (231-239)
    - rtn (292-325)
    - update_block_size (160-176)
  - modelopt/onnx/quantization/qdq_utils.py (2)
    - insert_dq_nodes (349-419)
    - insert_qdq_nodes (422-476)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (7)
examples/windows/onnx_ptq/genai_llm/quantize.py (1)
428-444
: LGTM: plumbing new flags into quantize_int4. Flags are passed through consistently and printed for visibility.
modelopt/onnx/quantization/int4.py (1)
221-231
: LGTM: precision_info propagation and per-channel handling. Per-weight bits flow cleanly into scale computation, Q/DQ insertion, and attribute updates.
Also applies to: 272-279, 293-300, 572-579, 1274-1283, 1290-1297
modelopt/onnx/quantization/qdq_utils.py (2)
349-356
: LGTM: per-weight dtype and attrs wired into DQ/QDQ insertion. Correctly derives dtype from num_bits and updates attributes for per-channel.
Also applies to: 422-427
58-63
: Verify ONNX/ORT support for INT4/UINT4 (verification failed here: onnx not installed). Confirm CI/runtime ONNX >= 1.16 and that the ORT build exposes TensorProto.INT4/UINT4.
Affects: modelopt/onnx/quantization/qdq_utils.py (lines 58–63, 311–325)
Run locally/CI and paste output:
```sh
python - <<'PY'
import sys
try:
    import onnx, onnxruntime as ort
except Exception as e:
    print("ERROR:", e, file=sys.stderr); sys.exit(2)
print("ONNX version:", onnx.__version__)
print("Has INT4:", hasattr(onnx.TensorProto, "INT4"))
print("Has UINT4:", hasattr(onnx.TensorProto, "UINT4"))
print("ORT version:", ort.__version__)
PY
```

modelopt/onnx/quantization/graph_utils.py (1)
669-696
: LGTM: token-exact INT8 layer matching avoids substring false positives. Good robustness improvement over naive substring checks.
modelopt/onnx/quantization/quant_utils.py (2)
200-215
: LGTM: fixed quantize_axis bounds check in _pad. Prevents out-of-bounds indexing.
242-290
: LGTM: scale computation clamps to CLIP_MIN and validates zero-points. Improves numerical stability and correctness for affine quantization.
LGTM
LGTM
b069cec to 6201e7b (Compare)
Actionable comments posted: 0
🧹 Nitpick comments (5)
modelopt/onnx/quantization/int4.py (3)
149-156
: Keep zero-point dtype consistent (don't cast zp to weight dtype). Pass zp unchanged to rtn. Casting zp to w.dtype can lead to subtle rounding differences; zp is an integer offset by definition.
Apply:
```diff
-            qw = rtn(
-                np.asarray(w),
-                s,
-                block_size_updated,
-                quantize_axis=gather_quantize_axis,
-                zp=zp if zp is None else zp.astype(w.dtype),
-                num_bits=num_bits,
-            )
+            qw = rtn(
+                np.asarray(w),
+                s,
+                block_size_updated,
+                quantize_axis=gather_quantize_axis,
+                zp=zp,
+                num_bits=num_bits,
+            )
```
1327-1333
: Avoid mutable default for nodes_to_exclude in public API. Use None default to prevent cross-call mutation.
Apply:
```diff
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
 @@
-    nodes_to_exclude = nodes_to_exclude or []
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
```
97-101
: Remove unused CLIP_MIN constant here (dup in quant_utils). It's duplicated and unused in this module.
Apply:
```diff
-# following min-value for clip is taken from AutoAWQ where zero-point based quantization is
-# supported and working
-CLIP_MIN = 1e-5
```

modelopt/onnx/quantization/graph_utils.py (2)
698-705
: Regex is fine; consider re.fullmatch for clarity (optional). Anchors already enforce full-string match; switching to re.fullmatch would be clearer but not required.
707-711
: Avoid mutable default list for nodes_to_exclude. Prevent shared state across calls.
Apply:
```diff
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
 @@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [r"/lm_head"])
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/windows/onnx_ptq/genai_llm/README.md (1 hunks)
- examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
- modelopt/onnx/quantization/graph_utils.py (2 hunks)
- modelopt/onnx/quantization/int4.py (27 hunks)
- modelopt/onnx/quantization/qdq_utils.py (8 hunks)
- modelopt/onnx/quantization/quant_utils.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/windows/onnx_ptq/genai_llm/README.md
- examples/windows/onnx_ptq/genai_llm/quantize.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-22T21:14:46.661Z
Learnt from: gcunhase
PR: NVIDIA/TensorRT-Model-Optimizer#354
File: modelopt/onnx/quantization/qdq_utils.py:874-944
Timestamp: 2025-09-22T21:14:46.661Z
Learning: In ModelOpt's ONNX quantization pipeline, DequantizeLinear nodes always have 3 inputs (data, scale, zero_point) and zero-points are always provided as initializers, not as Constant nodes. This means the `remove_graph_input_q` function can safely assume `dq_node[0].input[2]` exists and is an initializer.
Applied to files:
modelopt/onnx/quantization/int4.py
modelopt/onnx/quantization/qdq_utils.py
🧬 Code graph analysis (3)
- modelopt/onnx/quantization/int4.py (3)
  - modelopt/onnx/quantization/graph_utils.py (1)
    - get_precision_info (812-841)
  - modelopt/onnx/quantization/quant_utils.py (8)
    - _pad (200-214)
    - dq_tensor (328-344)
    - find_scales (242-289)
    - get_num_bits (179-193)
    - quant_tensor (347-365)
    - reshape_scales_for_per_channel_nodes (231-239)
    - rtn (292-325)
    - update_block_size (160-176)
  - modelopt/onnx/quantization/qdq_utils.py (2)
    - insert_dq_nodes (349-419)
    - insert_qdq_nodes (422-476)
- modelopt/onnx/quantization/qdq_utils.py (1)
  - modelopt/onnx/quantization/quant_utils.py (1)
    - get_num_bits (179-193)
- modelopt/onnx/quantization/quant_utils.py (1)
  - modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
    - num_bits (178-180)
    - num_bits (183-185)
    - axis (277-279)
    - axis (282-284)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (19)
modelopt/onnx/quantization/int4.py (6)
124-131
: Correct: block size updated per weight bit-width. Using get_num_bits + update_block_size for Gather is the right call for per-channel INT8.
176-177
: Good: per-channel scale reshape for Gather. reshape_scales_for_per_channel_nodes correctly handles 8-bit/per-channel shaping.
224-231
: Mixed-precision plumbing LGTM. Threading precision_info through scale/quant and using update_block_size is sound.
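As a reading aid, a minimal sketch of the per-weight lookup this plumbing enables; the dict shape (initializer name -> bit width) and the INT4 fallback are assumptions based on this review, not the module's exact code.

```python
# Assumed mapping: weight initializer name -> bit width; anything absent stays INT4.
precision_info = {"model.layers.0.mlp.down_proj.MatMul.weight": 8}

def lookup_num_bits(precision_info: dict | None, weight_name: str, default: int = 4) -> int:
    # Illustrative stand-in for quant_utils.get_num_bits.
    if not precision_info:
        return default
    return precision_info.get(weight_name, default)

for name in (
    "model.layers.0.mlp.down_proj.MatMul.weight",      # listed -> 8 bits
    "model.layers.1.self_attn.q_proj.MatMul.weight",   # not listed -> 4 bits
):
    print(name, "->", lookup_num_bits(precision_info, name), "bits")
```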
283-289
: Fix previously reported Gather DQ scales misuse — verified. You're now passing gather_s_map (not GEMM scales) into insert_dq_nodes for Gather. Issue addressed.
572-579
: AWQ-clip DQ insertion honors per-channel attributes. Using reshape_scales_for_per_channel_nodes + qdq.insert_dq_nodes with precision_info is correct.
976-988
: Mixed-precision disabled for per-subgraph AWQ-lite (intent?). The assertion forbids precision_info with awqlite_run_per_subgraph=True. Confirm that's intentional and document via a log/warning.
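If the restriction is intentional, one hedged way to surface it (the current code asserts; the names below are illustrative, not the module's actual variables):

```python
import logging

logger = logging.getLogger(__name__)

def resolve_precision_info(precision_info, awqlite_run_per_subgraph: bool):
    # Illustrative guard: warn and fall back to uniform INT4 instead of failing silently.
    if precision_info and awqlite_run_per_subgraph:
        logger.warning(
            "Mixed-precision (INT4+INT8) is not supported with per-subgraph AWQ-lite; "
            "ignoring precision_info for this run."
        )
        return None
    return precision_info
```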
modelopt/onnx/quantization/qdq_utils.py (5)
311-325
: Validate num_bits to fail fast with a clear error. Guard lookups to avoid KeyError for unsupported bit-widths.
Apply:
```diff
 def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
 @@
-    if has_zero_point:
-        dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
-    else:
-        dtype_str = onnx_bit_dtype_signed_map[num_bits]
-    return onnx_dtype_map[dtype_str]
+    valid_bits = onnx_bit_dtype_unsigned_map.keys() & onnx_bit_dtype_signed_map.keys()
+    if num_bits not in valid_bits:
+        raise ValueError(f"Unsupported num_bits={num_bits}; expected one of {sorted(valid_bits)}")
+    dtype_str = (
+        onnx_bit_dtype_unsigned_map[num_bits]
+        if has_zero_point
+        else onnx_bit_dtype_signed_map[num_bits]
+    )
+    return onnx_dtype_map[dtype_str]
```
327-338
: Per-channel attr update logic is appropriate. Axis override and dropping block_size for 8-bit/per-channel looks good.
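A small sketch of the attribute handling described here; the attribute names follow ONNX DequantizeLinear conventions, and the exact helper in qdq_utils may differ.

```python
def update_attrs_for_per_channel(attrs: dict, num_bits: int, block_size: int, axis: int = 0) -> dict:
    # For 8-bit or per-channel (block_size == -1) weights, DequantizeLinear carries only an
    # axis attribute; a block_size attribute is only meaningful for blocked INT4.
    attrs = dict(attrs)
    if num_bits == 8 or block_size == -1:
        attrs.pop("block_size", None)
        attrs["axis"] = axis
    else:
        attrs["block_size"] = block_size
    return attrs

print(update_attrs_for_per_channel({"block_size": 128}, num_bits=8, block_size=-1))  # {'axis': 0}
```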
349-357
: DQ insertion API extension LGTM. precision_info threading and per-weight dtype/attrs are correct.
422-469
: QDQ insertion handles per-weight bit-widths correctly. Using get_tensor_dtype with num_bits and creating zp/outputs accordingly is sound.
945-1014
: Assumption aligns with repo practice: DQ has 3 inputs, zp initializer. remove_graph_input_q relies on zp being an initializer; your insert_qdq/insert_dq_nodes satisfy this.
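For reference, a standalone example of the 3-input DequantizeLinear layout this assumption relies on, built with the plain ONNX helpers rather than ModelOpt's insertion code:

```python
import numpy as np
from onnx import helper, numpy_helper

# Scale and zero-point are materialized as initializers, so input[2] of the DQ node always exists.
scale = numpy_helper.from_array(np.array([0.02], dtype=np.float32), "w_scale")
zero_point = numpy_helper.from_array(np.array([0], dtype=np.int8), "w_zp")

dq_node = helper.make_node(
    "DequantizeLinear",
    inputs=["w_quantized", "w_scale", "w_zp"],  # data, scale, zero_point
    outputs=["w_dequantized"],
    name="w_DequantizeLinear",
)
print(len(dq_node.input))  # 3; scale/zero_point would be appended to graph.initializer
```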
modelopt/onnx/quantization/graph_utils.py (3)
669-696
: Exact token matching avoids substring false-positives. Sliding token window fix is correct and robust for layer matching.
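A sketch of the token-exact idea (split names into tokens and compare a sliding window), written as a generic helper rather than the exact graph_utils implementation:

```python
import re

def contains_layer_tokens(node_name: str, layer_spec: str) -> bool:
    # Split on the usual ONNX name separators and require the spec's tokens to appear as a
    # contiguous, exact run -- so "layers.1" does not match "layers.11".
    node_tokens = [t for t in re.split(r"[./]", node_name) if t]
    spec_tokens = [t for t in re.split(r"[./]", layer_spec) if t]
    m = len(spec_tokens)
    return any(node_tokens[i:i + m] == spec_tokens for i in range(len(node_tokens) - m + 1))

print(contains_layer_tokens("/model/layers.1/mlp/down_proj/MatMul", "layers.1"))   # True
print(contains_layer_tokens("/model/layers.11/mlp/down_proj/MatMul", "layers.1"))  # False
```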
812-841
: Same mutable default issue in get_precision_info. Mirror the fix here.
Apply:
```diff
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
 @@
-    if enable_mixed_quant:
-        precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude)
+    if enable_mixed_quant:
+        precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude or [r"/lm_head"])
```
735-801
: Heuristic 8-bit selection is reasonable. Pattern-based grouping and sampling (first/last eighth + every third) is acceptable as a default.
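A rough illustration of that sampling (first and last eighth plus every third layer); the real selection in get_layer_precision_mapping groups by op pattern and may differ in detail.

```python
def pick_8bit_layer_indices(num_layers: int) -> list[int]:
    # Keep the first and last eighth of the decoder layers plus every third layer in 8-bit;
    # the rest stays 4-bit. Purely illustrative of the heuristic described above.
    eighth = max(1, num_layers // 8)
    selected = set(range(eighth)) | set(range(num_layers - eighth, num_layers))
    selected |= set(range(0, num_layers, 3))
    return sorted(selected)

print(pick_8bit_layer_indices(32))  # a mix of early, late, and every-third layer indices
```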
modelopt/onnx/quantization/quant_utils.py (5)
160-177
: Block size updater LGTM. Handles per-channel for 8-bit and -1 sentinel correctly.
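A quick usage illustration, assuming the signature shown in the earlier diff; with 8-bit weights or block_size == -1 the returned block size spans the whole quantize axis.

```python
import numpy as np
from modelopt.onnx.quantization.quant_utils import update_block_size

w = np.random.rand(4096, 4096).astype(np.float32)
print(update_block_size(num_bits=4, block_size=128, quantize_axis=0, w=w))  # 128: blocked INT4
print(update_block_size(num_bits=8, block_size=128, quantize_axis=0, w=w))  # 4096: per-channel INT8
print(update_block_size(num_bits=4, block_size=-1, quantize_axis=0, w=w))   # 4096: per-channel INT4
```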
200-215
: Axis bounds check fixed. Using 0 <= quantize_axis < len(w.shape) prevents out-of-range errors.
242-290
: Scale/zero-point computation handles edge cases. CLIP_MIN clamp and zp range validation look good.
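A condensed sketch of the clamping idea being approved (symmetric per-block scales with a lower bound); CLIP_MIN matches the constant named in this review, the rest is illustrative.

```python
import numpy as np

CLIP_MIN = 1e-5  # lower bound before dividing, per the review

def find_block_scales(w_blocked: np.ndarray, num_bits: int = 4) -> np.ndarray:
    # w_blocked: (num_blocks, block_size); symmetric signed quantization.
    qmax = 2 ** (num_bits - 1) - 1                 # 7 for INT4, 127 for INT8
    amax = np.abs(w_blocked).max(axis=-1, keepdims=True)
    return np.maximum(amax, CLIP_MIN) / qmax       # clamp avoids zero scales on all-zero blocks

blocks = np.array([[0.0, 0.0, 0.0, 0.0], [0.5, -1.0, 0.25, 0.1]], dtype=np.float32)
print(find_block_scales(blocks, num_bits=4))
```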
292-326
: Quantization math is consistent with per-channel handling. Repeat logic and clipping per num_bits are correct.
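And a matching sketch of the round-to-nearest step with bit-width-dependent clipping; the repeat-based scale broadcast is shown in simplified form.

```python
import numpy as np

def rtn_quantize(w: np.ndarray, scale: np.ndarray, block_size: int, num_bits: int = 4) -> np.ndarray:
    # scale holds one value per block along axis 0; repeat expands it back to per-element.
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1  # [-8, 7] or [-128, 127]
    scale_full = np.repeat(scale, block_size, axis=0)
    return np.clip(np.rint(w / scale_full), qmin, qmax)

w = np.linspace(-1.0, 1.0, 8, dtype=np.float32).reshape(8, 1)
scale = np.array([[0.15], [0.15]], dtype=np.float32)  # two blocks of four rows each
print(rtn_quantize(w, scale, block_size=4, num_bits=4))
```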
231-240
: Per-channel scale reshape helper LGTM. Conditioning on num_bits==8 or block_size==-1 is appropriate.
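A toy illustration of why the reshape matters: with one block per output channel, the blocked scale array collapses to the 1-D per-channel vector DequantizeLinear expects (shapes are illustrative).

```python
import numpy as np

blocked_scales = np.array([[0.02, 0.03, 0.05]], dtype=np.float32)  # (1, num_channels): one block per channel
per_channel_scales = blocked_scales.reshape(-1)                    # (num_channels,)
print(blocked_scales.shape, "->", per_channel_scales.shape)
```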
6201e7b to b069cec (Compare)
What does this PR do?
Type of change: new feature
Overview: Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models
Usage
Testing
Ran sanity/functional tests, MMLU accuracy checks, and performance measurements for the models below.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Improvements
Documentation