[5506930] Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models #310
Conversation
Walkthrough
Adds per-weight mixed-precision (INT4/INT8) quantization support, centralizes quant math into shared utilities, extends Q/DQ and QDQ insertion to honor per-weight bit widths, and exposes new CLI options.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant CLI as quantize.py (CLI)
participant INT4 as modelopt.onnx.quantization.int4.quantize
participant MAP as graph_utils.get_precision_info
participant PIPE as RTN/AWQ/AWQ-lite
participant QUTIL as quant_utils
participant QDQ as qdq_utils
CLI->>INT4: quantize(..., enable_mixed_quant, int8_layers)
alt enable_mixed_quant == true
INT4->>MAP: get_precision_info(model, int8_layers)
MAP-->>INT4: precision_info (weight -> 4|8)
else
INT4-->>INT4: precision_info = None
end
INT4->>PIPE: run quant path(..., precision_info)
PIPE->>QUTIL: find_scales / quant_tensor / rtn(..., num_bits via precision_info)
QUTIL-->>PIPE: quantized weights, scales, zero-points
PIPE->>QDQ: insert_dq_nodes / insert_qdq_nodes(..., precision_info)
QDQ-->>PIPE: nodes with per-weight dtype (INT4/INT8) and zero-point handling
PIPE-->>INT4: updated ONNX model
INT4-->>CLI: write quantized model
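To make the walkthrough concrete, here is a small sketch of the per-weight precision map described above (weight initializer name mapped to 4 or 8 bits); the tensor names are hypothetical, and the fallback of 4 bits mirrors the default documented for get_num_bits later in this review.

```python
# Illustrative only: the shape of the precision_info mapping produced by
# graph_utils.get_precision_info (weight name -> bit width). Names are made up.
precision_info = {
    "model.layers.0.attn.qkv_proj.MatMul.weight": 8,  # forced to INT8 (per-channel)
    "model.layers.5.mlp.down_proj.MatMul.weight": 4,  # stays INT4 (block-wise)
}

def bits_for(name: str) -> int:
    # Weights not listed fall back to the documented default of 4 bits.
    return precision_info.get(name, 4)

assert bits_for("model.layers.0.attn.qkv_proj.MatMul.weight") == 8
assert bits_for("model.layers.99.unlisted.weight") == 4
```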
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
modelopt/onnx/quantization/int4.py (6)
Lines 579-587: Fix: pass a GraphSurgeon graph to _quantize_gather_nodes in AWQ-clip.
_quantize_gather_nodes iterates graph.nodes (GraphSurgeon), but here graph is an ONNX GraphProto (augmented_model.graph). This will throw at runtime. Apply:

-    gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
-        graph,
+    gather_w_map, gather_s_map, _ = _quantize_gather_nodes(
+        graph_gs,
         nodes_to_exclude,
         gather_quantize_axis,
         gather_block_size,
         use_zero_point=False,
         dq_only=True,
     )
Lines 1327-1334: Per-channel attribute missing for Gather DQ nodes in AWQ-lite. Mirror the per-channel handling for Gather weights too.

     qdq.insert_dq_nodes(
         graph_gs,
         gather_s_map,
         quantized_weights=gather_w_map,
         attributes=gather_dq_node_attributes,
         zero_points=gather_zp_map if use_zero_point else None,
-        precision_info=precision_info,
+        precision_info=precision_info,
+        is_per_channel=is_per_channel,
     )
Lines 159-166: Avoid mutable default list for nodes_to_exclude in RTN API.

-def quantize_rtn(
+def quantize_rtn(
     onnx_model: onnx.ModelProto,
     block_size: int,
     dq_only: bool = False,
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
     precision_info: dict[str, int] | None = None,
     **kwargs: Any,
 ) -> onnx.ModelProto:
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])
Lines 457-468: Avoid mutable default list for nodes_to_exclude in AWQ-clip API.

-def _quantize_awq_clip(
+def _quantize_awq_clip(
     onnx_model: onnx.ModelProto,
@@
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])
Lines 1000-1014: Avoid mutable default list for nodes_to_exclude in AWQ-lite API.

-def _quantize_awq_lite(
+def _quantize_awq_lite(
@@
-    nodes_to_exclude: list[str] = [],
+    nodes_to_exclude: list[str] | None = None,
@@
-    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [])
Lines 1470-1484: Avoid mutable default list for nodes_to_exclude; add params doc for mixed precision.

-def quantize(
+def quantize(
@@
-    nodes_to_exclude: list[str] | None = [r"/lm_head"],
+    nodes_to_exclude: list[str] | None = None,
@@
-    logger.debug(f"Excluding nodes matching patterns: {nodes_to_exclude}")
+    nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+    logger.debug(f"Excluding nodes matching patterns: {nodes_to_exclude}")

Optional: Update the docstring to mention k_quant_mixed and int8_layers behavior.
🧹 Nitpick comments (8)
modelopt/onnx/quantization/quant_utils.py (3)
Lines 169-171: Consider inlining this simple calculation. This helper function is only used once (in _pad) and performs a trivial calculation. Consider inlining it directly where used to reduce unnecessary abstraction.

-def _next_block_size_multiple(x: float, block_size: int) -> float:
-    return math.ceil(x / block_size) * block_size
-
 def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray:
     """Pads `w` to next largest multiple of block_size, on quantize_axis."""
     assert quantize_axis <= len(w.shape), (
         f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
     )
     if w.shape[quantize_axis] % block_size == 0:
         return w
-    pad_width = (
-        _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis]
-    )
+    pad_width = math.ceil(w.shape[quantize_axis] / block_size) * block_size - w.shape[quantize_axis]
Lines 190-201: Consider using numpy slicing syntax for all axes. The function handles only 2D arrays but could be made more general and consistent using numpy's ellipsis syntax.

 def _depad(w: np.ndarray, orig_shape: tuple, quantize_axis: int = 0) -> np.ndarray:
     """Depad quantize_axis to original shape."""
     if w.shape == orig_shape:
         return w
-    ans = None
-    if quantize_axis == 0:
-        ans = w[0 : orig_shape[0], ...]
-    elif quantize_axis == 1:
-        ans = w[..., 0 : orig_shape[1]]
-    else:
-        raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array")
-    return ans
+
+    if quantize_axis not in [0, 1]:
+        raise ValueError("Incorrect Quantize-axis: it must be 0 or 1 for a 2D array")
+
+    slices = [slice(None)] * len(w.shape)
+    slices[quantize_axis] = slice(0, orig_shape[quantize_axis])
+    return w[tuple(slices)]
Lines 323-338: Add docstring details about return values. The function's docstring should specify what the three return values represent.

 def quant_tensor(
     w: np.ndarray,
     block_size: int,
     quantize_axis: int = 0,
     alpha: float = 1.0,
     use_zero_point: bool = False,
     precision_info: dict[str, int] | None = None,
     name: str | None = None,
 ):
-    """Quantize a tensor using alpha etc. and return the quantized tensor."""
+    """Quantize a tensor using alpha etc. and return the quantized tensor.
+
+    Returns:
+        tuple: A tuple containing:
+            - wq: The quantized weight tensor (np.ndarray)
+            - scale: The scale factors used for quantization (np.ndarray)
+            - zp: The zero-point values (np.ndarray or None if not using zero-point)
+    """
     scale, zp = find_scales(
         w, block_size, quantize_axis, alpha, use_zero_point, precision_info, name
     )
     wq = rtn(w, scale, block_size, quantize_axis, zp, precision_info, name)
     return wq, scale, zp

examples/windows/onnx_ptq/genai_llm/quantize.py (1)
Lines 599-613: Consider validating the int8_layers pattern format. The int8_layers parameter accepts a comma-separated string of layer patterns, but there's no validation of the format. Consider adding basic validation to catch user errors early.

 parser.add_argument(
     "--int8_layers",
     type=str,
     default="",
     help=(
         "Comma-separated list of layer patterns to quantize to INT8 instead of INT4."
         "Example: 'layers.0,layers.1,lm_head'"
     ),
 )
+
+def validate_int8_layers(layers_str: str) -> bool:
+    """Validate the format of int8_layers string."""
+    if not layers_str:
+        return True
+    # Basic validation: check for valid characters and structure
+    import re
+    pattern = r'^[a-zA-Z0-9_.,\-]+$'
+    return bool(re.match(pattern, layers_str))

Then add validation after parsing:

 if args.int8_layers and not validate_int8_layers(args.int8_layers):
     parser.error("Invalid format for --int8_layers. Use comma-separated layer patterns.")

modelopt/onnx/quantization/qdq_utils.py (3)
Lines 310-322: Add input validation for the attributes parameter. The function modifies the attributes dictionary in place, which could cause unexpected side effects if the same dictionary is reused.

 def update_attributes(attrib: dict[str, Any] | None = None, is_per_channel: bool = False):
     """Update attribute dictionary for quantization nodes.

     If per-channel quantization is enabled, sets the 'axis' attribute to 1
     and removes the 'block_size' attribute if present.
+
+    Args:
+        attrib: Attribute dictionary to update (will be modified in place if not None)
+        is_per_channel: Whether to enable per-channel quantization
+
+    Returns:
+        The updated attribute dictionary
     """
     if is_per_channel:
         if attrib is not None:
             attrib["axis"] = 1
             if "block_size" in attrib:
                 attrib.pop("block_size")
     return attrib
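A quick usage sketch of the helper quoted above, assuming it behaves exactly as shown at this PR's revision:

```python
from modelopt.onnx.quantization.qdq_utils import update_attributes

# Per-channel: axis switches to 1 and block_size is dropped.
assert update_attributes({"axis": 0, "block_size": 128}, is_per_channel=True) == {"axis": 1}

# Not per-channel: the attributes pass through unchanged.
assert update_attributes({"axis": 0, "block_size": 128}, is_per_channel=False) == {
    "axis": 0,
    "block_size": 128,
}
```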
Lines 354-369: Consider extracting tensor dtype selection logic to a helper function. The tensor dtype selection logic is duplicated in both insert_dq_nodes and insert_qdq_nodes. Consider extracting it to reduce duplication and improve maintainability.

+def get_tensor_dtype(precision_info: dict[str, int] | None, name: str, has_zero_point: bool) -> int:
+    """Get the appropriate tensor dtype based on precision info and zero point presence.
+
+    Args:
+        precision_info: Dictionary mapping tensor names to bit widths
+        name: Name of the tensor
+        has_zero_point: Whether the tensor has a zero point
+
+    Returns:
+        ONNX tensor data type constant
+    """
+    if precision_info and name in precision_info:
+        bit_width = precision_info[name]
+        if has_zero_point:
+            dtype_str = onnx_bit_dtype_unsigned_map[bit_width]
+        else:
+            dtype_str = onnx_bit_dtype_signed_map[bit_width]
+        return onnx_dtype_map[dtype_str]
+    else:
+        # Default to 4-bit
+        return onnx.TensorProto.INT4 if not has_zero_point else onnx.TensorProto.UINT4

 def _insert_helper(
     name: str,
     wq: np.ndarray,
     scale: np.ndarray,
     dq_nodes: dict[str, gs.Node],
     zp: np.ndarray,
     attrs: dict[str, Any] | None = None,
     precision_info: dict[str, int] | None = None,
     is_per_channel: bool = False,
 ):
     attrib = dict(attrs) if attrs is not None else None
-    if precision_info and name in precision_info:
-        tensor_dtype = (
-            onnx_dtype_map[onnx_bit_dtype_signed_map[precision_info[name]]]
-            if zp is None
-            else onnx_dtype_map[onnx_bit_dtype_unsigned_map[precision_info[name]]]
-        )
-        # do per-channel quantization for int8 as no support for int8 block-wise dq node
-        if precision_info[name] == 8:
-            # reshape scale to be per-channel
-            scale = scale.reshape(-1)
-            attrib = update_attributes(attrib, True)
-    else:
-        tensor_dtype = onnx.TensorProto.INT4 if zp is None else onnx.TensorProto.UINT4
+    tensor_dtype = get_tensor_dtype(precision_info, name, zp is not None)
+
+    # do per-channel quantization for int8 as no support for int8 block-wise dq node
+    if precision_info and name in precision_info and precision_info[name] == 8:
+        # reshape scale to be per-channel
+        scale = scale.reshape(-1)
+        attrib = update_attributes(attrib, True)
Lines 361-364: Validate per-channel scale size before flattening. Although scale.reshape(-1) preserves all elements, add an explicit check (e.g. assert scale.size == wq.shape[1]) before calling reshape to guard against mismatches when quantizing along axis 1.

modelopt/onnx/quantization/int4.py (1)
Lines 620-624: Use logger.warning instead of deprecated logger.warn.

-    logger.warn("Augmented ONNX model or external data file was not found")
+    logger.warning("Augmented ONNX model or external data file was not found")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
modelopt/onnx/quantization/int4.py (29 hunks)
modelopt/onnx/quantization/qdq_utils.py (8 hunks)
modelopt/onnx/quantization/quant_utils.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/quant_utils.py (5)
_pad (173-187)
dq_tensor (297-320)
find_scales (204-252)
quant_tensor (323-337)
rtn (255-294)
modelopt/onnx/quantization/qdq_utils.py (2)
insert_dq_nodes (324-406)
insert_qdq_nodes (409-463)
modelopt/onnx/quantization/graph_utils.py (1)
expand_node_names_from_patterns (628-639)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
num_bits (180-182)
num_bits (185-187)
axis (279-281)
axis (284-286)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (10)
modelopt/onnx/quantization/quant_utils.py (3)
Lines 160-167: LGTM! Clean implementation of bit-width selection logic. The function provides a clear and simple way to get the number of bits for quantization based on the precision info dictionary, with a sensible default of 4 bits.
Lines 173-188: LGTM! Proper padding implementation with clear assertions. The padding logic correctly handles different axes and includes appropriate assertions for validation.
Lines 297-321: LGTM! Clear dequantization implementation. The dequantization logic correctly handles both signed and unsigned cases with proper padding and depadding.
examples/windows/onnx_ptq/genai_llm/quantize.py (2)
Lines 368-369: LGTM! Status output correctly includes the new k_quant_mixed flag. The logging statement properly shows the new mixed-precision quantization option.
Lines 438-439: LGTM! Proper propagation of mixed-precision parameters. The new parameters are correctly passed to the quantize_int4 function.
modelopt/onnx/quantization/qdq_utils.py (2)
Lines 57-61: LGTM! Proper extension of dtype mappings for 4-bit support. The addition of INT4 and UINT4 to the dtype maps enables proper 4-bit quantization support.
Lines 432-440: LGTM! Consistent handling of per-weight precision in QDQ nodes. The implementation correctly mirrors the DQ node logic for handling per-weight precision information.
modelopt/onnx/quantization/int4.py (3)
Lines 206-209: LGTM: precision_info is correctly threaded in RTN path. Scale computation, quantization, and Q/DQ insertion honor per-weight precision and per-channel when applicable.
Also applies to: 240-243, 248-255, 259-265
Lines 315-322: LGTM: precision_info propagation in AWQ-clip compute path. The clip search and final quantization use per-weight precision consistently.
Also applies to: 362-366, 526-528, 552-555
Lines 705-707: LGTM: precision_info propagation in AWQ-lite compute path. Both per-node and per-subgraph scale searches respect mixed precision; quantization/dequantization calls pass the context correctly.
Also applies to: 762-771, 782-783, 882-884, 939-952, 1112-1131, 1182-1188
dq_node_attributes = {"axis": 0, "block_size": block_size}

qdq.insert_dq_nodes(
    graph_gs,
    scales,
    quantized_weights=gemm_weights_quantized,
    attributes=dq_node_attributes,
    zero_points=zero_points if use_zero_point else None,
    precision_info=precision_info,
)
🛠️ Refactor suggestion
Per-channel attribute not propagated in AWQ-lite.
When block_size == -1, you set per-channel in RTN and AWQ-clip, but not here. TRT plugins require this attribute for correct dequantization.
Apply:
dq_node_attributes = {"axis": 0, "block_size": block_size}
+ is_per_channel = block_size == -1
qdq.insert_dq_nodes(
graph_gs,
scales,
quantized_weights=gemm_weights_quantized,
attributes=dq_node_attributes,
zero_points=zero_points if use_zero_point else None,
- precision_info=precision_info,
+ precision_info=precision_info,
+ is_per_channel=is_per_channel,
)
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/int4.py around lines 1311-1320, the
dq_node_attributes only contains "axis" and "block_size" for AWQ-lite; when
block_size == -1 you must also set the per-channel flag (as done for RTN and
AWQ-clip) so TRT plugins get correct dequantization info. Modify the
construction of dq_node_attributes to include "per_channel": True when
block_size == -1 (otherwise keep it False or omit), then pass that attributes
dict into qdq.insert_dq_nodes so the per-channel attribute is propagated.
modelopt/onnx/quantization/int4.py (Outdated)
def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
    """Check if layer should be quantized to INT8.

    The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
    The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'.
    To match these, we:
    - Remove the leading slash from the node name.
    - Replace all '/' with '.' to match the naming convention of the initializer.
    This allows us to correctly identify which weights should be quantized to INT8.
    """
    if not int8_layers:
        return False
    normalized_patterns = []
    for pattern in int8_layers:
        p = pattern.lstrip("/")
        p = p.replace("/", ".")
        normalized_patterns.append(p)
    return any(norm_pattern in layer_name for norm_pattern in normalized_patterns)
🛠️ Refactor suggestion
Avoid substring false-positives when matching INT8 layers.
Substring matching with "in" can misclassify: e.g., layers.1 matches layers.11. Use token-exact matching.
-def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
+def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool:
@@
- normalized_patterns = []
- for pattern in int8_layers:
- p = pattern.lstrip("/")
- p = p.replace("/", ".")
- normalized_patterns.append(p)
- return any(norm_pattern in layer_name for norm_pattern in normalized_patterns)
+ # Normalize both to dot-delimited tokens and require exact token sequence match.
+ def tokens(s: str) -> list[str]:
+ return s.lstrip("/").replace("/", ".").split(".")
+ hay = tokens(layer_name)
+ for pat in int8_layers:
+ needle = tokens(pat)
+ n, m = len(hay), len(needle)
+ for i in range(0, n - m + 1):
+ if hay[i : i + m] == needle:
+ return True
+ return False
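A standalone illustration of the false positive this guards against (plain Python, no modelopt imports):

```python
# "layers.1" is a substring of "layers.11", so naive matching misfires.
layer_name = "model.layers.11.attn.qkv_proj.MatMul.weight"
pattern = "layers.1"

substring_hit = pattern in layer_name  # True -> would wrongly force layer 11 to INT8

hay = layer_name.split(".")
needle = pattern.split(".")
token_hit = any(hay[i : i + len(needle)] == needle for i in range(len(hay) - len(needle) + 1))

print(substring_hit, token_hit)  # True False
```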
modelopt/onnx/quantization/int4.py (Outdated)
def get_layer_precision_mapping(
    onnx_model: onnx.ModelProto,
    int8_precision_pattern: str | None = None,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
    graph = onnx_model.graph

    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
    # Collect quantizable weight tensors
🛠️ Refactor suggestion
Don't use a mutable list as a default argument: nodes_to_exclude: list[str] | None = [r"/lm_head"] risks cross-call mutation.
-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
onnx_model: onnx.ModelProto,
int8_precision_pattern: str | None = None,
- nodes_to_exclude: list[str] | None = [r"/lm_head"],
+ nodes_to_exclude: list[str] | None = None,
):
- graph = onnx_model.graph
+ graph = onnx_model.graph
+ nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
Optional: allow regex patterns in int8_precision_pattern to be expanded via expand_node_names_from_patterns for robustness.
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/int4.py around lines 1379 to 1387, the function
signature uses a mutable default list for nodes_to_exclude which can lead to
cross-call mutation; change the signature to use None as the default
(nodes_to_exclude: list[str] | None = None) and inside the function set
nodes_to_exclude = [r"/lm_head"] if nodes_to_exclude is None before calling
expand_node_names_from_patterns; additionally (optional) if
int8_precision_pattern should accept regexes, pass it through
expand_node_names_from_patterns or a similar helper so patterns are expanded
consistently before use.
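For readers unfamiliar with the pitfall behind these repeated suggestions, a minimal demonstration in plain Python:

```python
# The default list is created once and shared across calls.
def collect(item, bucket=[]):
    bucket.append(item)
    return bucket

print(collect("a"))  # ['a']
print(collect("b"))  # ['a', 'b']  <- state leaked from the previous call

# The None-default idiom suggested throughout this review avoids that.
def collect_fixed(item, bucket=None):
    bucket = [] if bucket is None else bucket
    bucket.append(item)
    return bucket

print(collect_fixed("a"))  # ['a']
print(collect_fixed("b"))  # ['b']
```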
def rtn(
    w: np.ndarray,
    s: np.ndarray,
    block_size: int,
    quantize_axis: int = 0,
    zp: np.ndarray = None,
    precision_info: dict[str, int] | None = None,
    name: str | None = None,
) -> np.ndarray:
    """Quantizes `w` with scale factors `s` via Round-to-Nearest.

    Ties are broken by rounding to the nearest even number.
    """
    num_bits = get_num_bits(precision_info, name)
    # If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node,
    # set block_size to the size of the quantize_axis dimension to do per-channel quantization
    if block_size == -1 or num_bits == 8:
        block_size = w.shape[quantize_axis]
    w_padded = _pad(w, block_size, quantize_axis)
    num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis]
    if zp is None:
        maxq = 2 ** (num_bits - 1) - 1
        minq = -(2 ** (num_bits - 1))
        w_padded = (
            np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
            .clip(minq, maxq)
            .astype(np.int8)
        )
    else:
        maxq = (2**num_bits) - 1
        minq = 0
        w_padded = (
            (
                np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
                + zp.repeat(num_blocks, axis=quantize_axis)
            )
            .clip(minq, maxq)
            .astype(np.int8)
        )
    return _depad(w_padded, w.shape, quantize_axis)
🛠️ Refactor suggestion
Improve numerical stability for edge cases.
When using zero-point quantization, division by s could cause issues if s contains values close to zero (even though clipped by CLIP_MIN). Consider adding a safety check.
def rtn(
w: np.ndarray,
s: np.ndarray,
block_size: int,
quantize_axis: int = 0,
zp: np.ndarray = None,
precision_info: dict[str, int] | None = None,
name: str | None = None,
) -> np.ndarray:
"""Quantizes `w` with scale factors `s` via Round-to-Nearest.
Ties are broken by rounding to the nearest even number.
"""
num_bits = get_num_bits(precision_info, name)
# If block_size == -1 and num_bits == 8 as no support for int8 block-wise dq node,
# set block_size to the size of the quantize_axis dimension to do per-channel quantization
if block_size == -1 or num_bits == 8:
block_size = w.shape[quantize_axis]
w_padded = _pad(w, block_size, quantize_axis)
num_blocks = w_padded.shape[quantize_axis] // s.shape[quantize_axis]
+
+ # Ensure scale values are not too small to avoid numerical instability
+ if np.any(np.abs(s) < CLIP_MIN):
+ raise ValueError(f"Scale values too small for stable quantization: min={np.min(np.abs(s))}")
+
if zp is None:
maxq = 2 ** (num_bits - 1) - 1
minq = -(2 ** (num_bits - 1))
w_padded = (
np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
.clip(minq, maxq)
.astype(np.int8)
)
else:
maxq = (2**num_bits) - 1
minq = 0
w_padded = (
(
np.rint(w_padded / s.repeat(num_blocks, axis=quantize_axis))
+ zp.repeat(num_blocks, axis=quantize_axis)
)
.clip(minq, maxq)
.astype(np.int8)
)
return _depad(w_padded, w.shape, quantize_axis)
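As a sanity check on the math being discussed, here is a self-contained numpy sketch of symmetric block-wise round-to-nearest quantization. It does not call modelopt and simplifies the blocking layout; the clamp line reflects the safety check suggested above.

```python
import numpy as np

w = np.random.randn(8, 256).astype(np.float32)                # synthetic weight
block_size, num_bits = 128, 4
maxq, minq = 2 ** (num_bits - 1) - 1, -(2 ** (num_bits - 1))  # 7 and -8 for INT4

blocks = w.reshape(w.shape[0], -1, block_size)                # (rows, n_blocks, block_size)
scale = np.abs(blocks).max(axis=-1, keepdims=True) / maxq
scale = np.maximum(scale, 1e-8)                               # avoid division by ~0 scales
wq = np.clip(np.rint(blocks / scale), minq, maxq).astype(np.int8)
w_dq = (wq * scale).reshape(w.shape)                          # dequantized approximation

print("max abs error:", np.abs(w - w_dq).max())
```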
)

parser.add_argument(
    "--k_quant_mixed",
What does "k" stand for in k_quant_mixed?
k stands for group-wise/block-wise quantization. We are doing INT4 block-wise + a few layers in INT8 per-channel quantization (currently we don't have INT8 block-wise DequantizeLinear node support on the trt-rtx side).
Need to make a few more changes as per the discussion with Vishal. Will request a review once the updated changes are ready.
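To visualize the difference the reply above describes (INT4 block-wise plus INT8 per-channel), a rough sketch of how many scale values each scheme produces for one MatMul weight; the shapes and axis orientation here are illustrative, not the exact layout used by the PR.

```python
in_features, out_features, block_size = 4096, 11008, 128

# INT4 block-wise: one scale per block of `block_size` elements along the quantized
# axis, for every output column.
int4_num_scales = (in_features // block_size) * out_features  # 32 * 11008

# INT8 per-channel: block_size effectively becomes the whole axis, so one scale per column.
int8_num_scales = out_features                                # 11008

print(int4_num_scales, int8_num_scales)
```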
…+INT8) ONNX models Signed-off-by: unknown <[email protected]>
…+INT8) ONNX models, refactored changes and handle comments Signed-off-by: unknown <[email protected]>
caa9965 to b6a39be (Compare)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)
Lines 417-450: Zero-point dtype should be unsigned. When creating Q/DQ, zp should use the unsigned dtype. Pass has_zero_point=True to get_tensor_dtype.

-    tensor_dtype = get_tensor_dtype(num_bits)
+    tensor_dtype = get_tensor_dtype(num_bits)
@@
-    zp_tensor = make_gs_zp(name, scale.shape, tensor_dtype)
+    zp_tensor = make_gs_zp(name, scale.shape, get_tensor_dtype(num_bits, has_zero_point=True))
♻️ Duplicate comments (1)
modelopt/onnx/quantization/quant_utils.py (1)
Lines 292-326: Numerical safety in rtn. Add a guard to avoid tiny s after broadcast.

-    if zp is None:
+    # Safety: avoid division by tiny scales
+    if np.any(np.abs(s) < CLIP_MIN):
+        s = np.where(np.abs(s) < CLIP_MIN, CLIP_MIN, s)
+    if zp is None:
         maxq = 2 ** (num_bits - 1) - 1
🧹 Nitpick comments (7)
modelopt/onnx/quantization/graph_utils.py (2)
Lines 629-667: Type hint misuse and minor robustness.
- cast("int", weight_tensor.data_type) uses a string literal as a type param; use cast(int, ...) or drop cast.
- Consider guarding for Gemm transA (currently TODO) or explicitly documenting unsupported cases.

-    gemm_io_type = cast("int", weight_tensor.data_type)
+    gemm_io_type = cast(int, weight_tensor.data_type)
Lines 669-696: Token-exact matcher looks good. Add fast exit and docstring type. Add -> bool and an early return for long patterns.

-def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
+def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool:
@@
-    for pat in int8_layers:
+    for pat in int8_layers:
         needle = tokens(pat)
         n, m = len(hay), len(needle)
+        if m == 0 or m > n:
+            continue

modelopt/onnx/quantization/quant_utils.py (2)
Lines 160-177: update_block_size: clarify contract for w=None. The function dereferences w unconditionally when block_size == -1 or num_bits == 8. Either assert or handle None.

-    if block_size is not None and (block_size == -1 or num_bits == 8):
-        return w.shape[quantize_axis]
+    if block_size is not None and (block_size == -1 or num_bits == 8):
+        assert w is not None, "update_block_size requires `w` when per-channel/INT8 is requested"
+        return w.shape[quantize_axis]
Lines 196-198: Type hint for _next_block_size_multiple. Return type is int in practice; adjust for readability/tools.

-def _next_block_size_multiple(x: float, block_size: int) -> float:
+def _next_block_size_multiple(x: float, block_size: int) -> int:

modelopt/onnx/quantization/qdq_utils.py (1)
Lines 50-63: ONNX 4-bit enums: confirm availability. onnx.TensorProto.INT4/UINT4 are relatively new; some environments may lack them. Add a fallback or guard. Option: probe availability once and raise a clear error.

-onnx_dtype_map = {
+onnx_dtype_map = {
@@
-    "INT4": onnx.TensorProto.INT4,
-    "UINT4": onnx.TensorProto.UINT4,
+    "INT4": getattr(onnx.TensorProto, "INT4", None),
+    "UINT4": getattr(onnx.TensorProto, "UINT4", None),
 }
+
+# Validate availability early
+if onnx_dtype_map["INT4"] is None or onnx_dtype_map["UINT4"] is None:
+    logger.error("This ONNX build does not support 4-bit dtypes (INT4/UINT4).")
+    # Optionally: raise, or fall back to INT8 path depending on product decision

Would you like me to wire a clean fallback to INT8 when 4-bit enums are missing?
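A tiny probe along the lines of the guard suggested above; it only checks whether the installed onnx build exposes the 4-bit enums.

```python
import onnx

has_int4 = hasattr(onnx.TensorProto, "INT4") and hasattr(onnx.TensorProto, "UINT4")
print(f"4-bit dtypes available in this onnx build: {has_int4}")
```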
modelopt/onnx/quantization/int4.py (2)
Lines 221-227: Pattern-vs-name expansion mismatch risk. nodes_to_exclude is already expanded to concrete names before calling get_precision_info, which expects patterns (it internally expands again). This is benign but redundant and can mis-handle special regex chars in names. Two options:
- Pass the original patterns into get_precision_info, or
- Change get_precision_info to accept already-expanded names and skip expansion.
: Doc/type mismatch for int8_layers.Docstring says “comma-separated list” but default shows
[]
and the code expects a string or None. Clarify.- - **int8_layers** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4. - Default: []. + - **int8_layers** (str | None): comma-separated list of layer patterns to quantize to INT8 instead of INT4. + Default: None.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
examples/windows/onnx_ptq/genai_llm/README.md (1 hunks)
examples/windows/onnx_ptq/genai_llm/quantize.py (3 hunks)
modelopt/onnx/quantization/graph_utils.py (2 hunks)
modelopt/onnx/quantization/int4.py (27 hunks)
modelopt/onnx/quantization/qdq_utils.py (8 hunks)
modelopt/onnx/quantization/quant_utils.py (3 hunks)
✅ Files skipped from review due to trivial changes (1)
- examples/windows/onnx_ptq/genai_llm/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/windows/onnx_ptq/genai_llm/quantize.py
🧰 Additional context used
🧬 Code graph analysis (3)
modelopt/onnx/quantization/qdq_utils.py (1)
modelopt/onnx/quantization/quant_utils.py (1)
get_num_bits (179-193)
modelopt/onnx/quantization/int4.py (3)
modelopt/onnx/quantization/graph_utils.py (2)
_find_quantizable_weights (629-666)
get_precision_info (814-843)
modelopt/onnx/quantization/quant_utils.py (8)
_pad (200-214)
dq_tensor (328-344)
find_scales (242-289)
get_num_bits (179-193)
quant_tensor (347-366)
rtn (292-325)
update_block_size (160-176)
update_scale_map_for_per_channel_nodes (231-239)
modelopt/onnx/quantization/qdq_utils.py (2)
insert_dq_nodes (327-399)
insert_qdq_nodes (402-456)
modelopt/onnx/quantization/quant_utils.py (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
num_bits (180-182)
num_bits (185-187)
axis (279-281)
axis (284-286)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
modelopt/onnx/quantization/int4.py (2)
Lines 562-571: Per-channel handling is correctly propagated. Good use of update_scale_map_for_per_channel_nodes and attribute adjustment in qdq insertion to set axis=1 and drop block_size for INT8/per-channel. This aligns with TRT plugin expectations.
Also applies to: 1264-1273
Lines 244-253: Gather path: mixed-precision propagation looks correct. You compute per-weight num_bits and reuse update_block_size for gather. Nice reuse and consistency across RTN/AWQ paths.
Also applies to: 1253-1261
def validate_int8_layers(layers_str: str) -> bool:
    """Validate the format of int8_layers string."""
    if not layers_str:
        return True
    # Basic validation: check for valid characters and structure
    import re

    pattern = r"^[a-zA-Z0-9_.,\-]$"
    return bool(re.match(pattern, layers_str))
Fix regex validation to accept real layer patterns.
The current pattern r"^[a-zA-Z0-9_.,\-]$" only matches a single character and disallows /. This will reject valid inputs like /model/layers.13/attn/qkv_proj/MatMul.
Apply:
-def validate_int8_layers(layers_str: str) -> bool:
+def validate_int8_layers(layers_str: str) -> bool:
@@
- pattern = r"^[a-zA-Z0-9_.,\-]$"
- return bool(re.match(pattern, layers_str))
+ # Allow comma-separated list of path-like tokens
+ pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
+ return bool(re.match(pattern, layers_str))
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/graph_utils.py around lines 698 to 707, the regex
only matches a single character and disallows "/" so valid layer paths are
rejected; update the validation to allow one-or-more characters and include "/"
(and keep letters, digits, underscore, dot, comma and hyphen), e.g. replace the
pattern with a character-class that includes "/" and a quantifier (and ensure
hyphen is either escaped or last), and use re.fullmatch instead of re.match so
the entire string is validated (e.g. pattern like "^[A-Za-z0-9_/.,-]+$" with
re.fullmatch).
Look into this suggestion
🧩 Analysis chain
@i-riyad I've analyzed the suggestion and it's correct - there's a critical bug in the regex validation.
The CodeRabbit suggestion is absolutely correct. The current regex has two major flaws:
- Single-character matching: the pattern ^[a-zA-Z0-9_.,\-]$ only matches exactly one character due to the absence of quantifiers.
- Missing forward slash: real ONNX node names like /model/layers.13/attn/qkv_proj/MatMul contain forward slashes, which aren't allowed.
The suggested fix properly addresses both issues by:
- Adding a + quantifier to match one or more characters
- Including / in the character class
- Supporting comma-separated values with optional whitespace
However, I'd recommend a small improvement to make the validation more robust:
def validate_int8_layers(layers_str: str) -> bool:
"""Validate the format of int8_layers string."""
if not layers_str:
return True
# Allow comma-separated list of path-like tokens
# Pattern allows: letters, digits, underscore, dot, slash, hyphen
# Supports comma-separated values with optional whitespace
pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
return bool(re.fullmatch(pattern, layers_str))
Using re.fullmatch() instead of re.match() provides clearer intent (matching the entire string) and is slightly more explicit, though functionally equivalent when using ^ and $ anchors.
The suggestion should definitely be implemented to fix this validation bug.
🏁 Script executed:
#!/bin/bash
# Let's examine the current validate_int8_layers function and test the regex behavior
cd /tmp
cat > test_regex.py << 'EOF'
import re
# Current implementation
def validate_int8_layers_current(layers_str: str) -> bool:
if not layers_str:
return True
pattern = r"^[a-zA-Z0-9_.,\-]$"
return bool(re.match(pattern, layers_str))
# Suggested implementation
def validate_int8_layers_suggested(layers_str: str) -> bool:
if not layers_str:
return True
pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
return bool(re.match(pattern, layers_str))
# Test cases
test_cases = [
"/model/layers.13/attn/qkv_proj/MatMul",
"/model/layers.0/attn/v_proj/MatMul,/model/layers.1/mlp/down_proj/MatMul",
"model.layers.13.attn.qkv_proj.MatMul.weight",
"a", # single character - should pass for current
"", # empty string - should pass for both
"/invalid/path with spaces/MatMul", # spaces in path - should fail
"valid,path", # simple comma-separated - should pass for suggested
]
print("Testing current vs suggested implementation:")
print("=" * 60)
for test in test_cases:
current_result = validate_int8_layers_current(test)
suggested_result = validate_int8_layers_suggested(test)
print(f"Input: '{test}'")
print(f" Current: {current_result}")
print(f" Suggested: {suggested_result}")
print()
EOF
python test_regex.py
Length of output: 634
Perfect! The test results confirm the analysis. The current implementation is severely broken - it only accepts single characters or empty strings, making it useless for validating real layer patterns.
Key findings:
- ✅ Current regex rejects ALL valid layer patterns like /model/layers.13/attn/qkv_proj/MatMul
- ✅ Suggested fix correctly accepts valid patterns while rejecting invalid ones (like paths with spaces)
- ✅ Both handle empty strings correctly
- ✅ Suggested pattern properly supports comma-separated values
Recommendation: Implement the CodeRabbit suggestion immediately. This is a critical bug that makes the validation function completely non-functional for its intended purpose.
Here's the corrected implementation:
def validate_int8_layers(layers_str: str) -> bool:
"""Validate the format of int8_layers string."""
if not layers_str:
return True
# Basic validation: check for valid characters and structure
import re
# Allow comma-separated list of path-like tokens
pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
return bool(re.match(pattern, layers_str))
The current regex would cause the mixed-precision quantization feature to fail silently or reject all user-provided layer specifications, which would be a significant usability issue.
def get_layer_precision_mapping(
    onnx_model: onnx.ModelProto,
    int8_precision_pattern: str | None = None,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
):
    """Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model.
🛠️ Refactor suggestion
Avoid mutable default args and normalize exclusions.
Using a list default (nodes_to_exclude=[r"/lm_head"]) risks cross-call mutation. Also normalize nodes_to_exclude up-front.
-def get_layer_precision_mapping(
+def get_layer_precision_mapping(
onnx_model: onnx.ModelProto,
int8_precision_pattern: str | None = None,
- nodes_to_exclude: list[str] | None = [r"/lm_head"],
+ nodes_to_exclude: list[str] | None = None,
):
@@
- nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+ nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+ nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
Also applies to: 728-731
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/graph_utils.py around lines 709-714 (and similarly
at 728-731), avoid using a mutable list as a default for nodes_to_exclude:
change the signature to use nodes_to_exclude: list[str] | None = None, then
inside the function set nodes_to_exclude = [r"/lm_head"] if nodes_to_exclude is
None else list(nodes_to_exclude) to prevent shared-state mutations; immediately
normalize entries (e.g., strip/ensure raw strings or compiled regex as the code
expects) so downstream logic can assume a consistent list type and format.
def get_precision_info(
    onnx_model: onnx.ModelProto,
    nodes_to_exclude: list[str] | None = [r"/lm_head"],
    **kwargs: Any,
):
    """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits).

    This function determines the quantization precision for each weight tensor in the ONNX model,
    based on the provided configuration. If mixed quantization is enabled, it uses the layer
    precision mapping; otherwise, it returns None.

    Args:
        onnx_model (onnx.ModelProto): The ONNX model to analyze.
        nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
        **kwargs: Additional keyword arguments, such as:
            - enable_mixed_quant (bool): Whether to enable mixed quantization.
            - int8_layers (str): Comma-separated list of layer patterns to quantize to INT8.

    Returns:
        dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
        or None if mixed quantization is not enabled.
    """
    precision_info = None
    enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
    int8_layers = kwargs.get("int8_layers")
    if enable_mixed_quant:
        precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude)
    else:
        precision_info = None
    return precision_info
🛠️ Refactor suggestion
Same mutable default issue in public API.
Apply same fix as above.
-def get_precision_info(
+def get_precision_info(
onnx_model: onnx.ModelProto,
- nodes_to_exclude: list[str] | None = [r"/lm_head"],
+ nodes_to_exclude: list[str] | None = None,
**kwargs: Any,
):
@@
- enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
+ nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+ enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/graph_utils.py around lines 814 to 844, the
function get_precision_info declares a mutable default parameter
nodes_to_exclude = [r"/lm_head"], which should be avoided; change the signature
to use nodes_to_exclude: list[str] | None = None (or Optional[list[str]] = None)
and inside the function set nodes_to_exclude = [r"/lm_head"] if nodes_to_exclude
is None, preserving the original default behavior while avoiding the mutable
default; update any type hints and callers if necessary.
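A hedged usage sketch of get_precision_info as quoted above; it assumes modelopt at this PR's revision, the model path is hypothetical, and the layer patterns are placeholders.

```python
import onnx
from modelopt.onnx.quantization.graph_utils import get_precision_info

model = onnx.load("model.onnx")  # hypothetical path

# Mixed quantization disabled (the default): everything stays on the INT4 path.
assert get_precision_info(model) is None

# Mixed quantization enabled with explicit INT8 layer patterns.
precision_info = get_precision_info(
    model,
    enable_mixed_quant=True,
    int8_layers="layers.0,lm_head",
)
# precision_info now maps weight initializer names to 4 or 8.
```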
def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
    """Get the appropriate tensor dtype based on precision info and zero point presence.

    Args:
        num_bits: Number of bits for quantization
        has_zero_point: Whether the tensor has a zero point

    Returns:
        ONNX tensor data type constant
    """
    if has_zero_point:
        dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
    else:
        dtype_str = onnx_bit_dtype_signed_map[num_bits]
    return onnx_dtype_map[dtype_str]
🛠️ Refactor suggestion
Validate num_bits and map lookups.
Guard against unexpected num_bits and derive dtype accordingly.
def get_tensor_dtype(num_bits: int = 4, has_zero_point: bool = False) -> int:
@@
- if has_zero_point:
- dtype_str = onnx_bit_dtype_unsigned_map[num_bits]
- else:
- dtype_str = onnx_bit_dtype_signed_map[num_bits]
+ valid_bits = onnx_bit_dtype_unsigned_map.keys() & onnx_bit_dtype_signed_map.keys()
+ if num_bits not in valid_bits:
+ raise ValueError(f"Unsupported num_bits={num_bits}; expected one of {sorted(valid_bits)}")
+ dtype_str = (
+ onnx_bit_dtype_unsigned_map[num_bits]
+ if has_zero_point
+ else onnx_bit_dtype_signed_map[num_bits]
+ )
return onnx_dtype_map[dtype_str]
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/qdq_utils.py around lines 311 to 325, the function
does unguarded lookups on num_bits and mapping dicts which can raise KeyError
for unexpected values; validate that num_bits is one of the supported keys
(e.g., check presence in onnx_bit_dtype_unsigned_map and
onnx_bit_dtype_signed_map) before using it, raise a clear ValueError if
unsupported, and guard the final onnx_dtype_map lookup similarly (or provide a
default/raise with context) so the function always fails fast with a clear
message instead of crashing with a KeyError.
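Assuming the bit-to-dtype maps pair 4 with INT4/UINT4 and 8 with INT8/UINT8 (only the 4-bit entries are shown verbatim in this diff), the helper's behavior would look like this:

```python
import onnx
from modelopt.onnx.quantization.qdq_utils import get_tensor_dtype

assert get_tensor_dtype(4, has_zero_point=False) == onnx.TensorProto.INT4
assert get_tensor_dtype(4, has_zero_point=True) == onnx.TensorProto.UINT4
assert get_tensor_dtype(8, has_zero_point=False) == onnx.TensorProto.INT8  # assumed mapping
```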
def _pad(w: np.ndarray, block_size: int, quantize_axis: int = 0) -> np.ndarray:
    """Pads `w` to next largest multiple of block_size, on quantize_axis."""
    assert quantize_axis <= len(w.shape), (
        f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
    )

    if w.shape[quantize_axis] % block_size == 0:
        return w

    pad_width = (
        _next_block_size_multiple(w.shape[quantize_axis], block_size) - w.shape[quantize_axis]
    )
    pads = [(0, 0) for _ in range(len(w.shape))]
    pads[quantize_axis] = (0, pad_width)
    return np.pad(w, pads, mode="constant", constant_values=0)
Out-of-bounds axis check in _pad (can crash).
quantize_axis <= len(w.shape) allows ==, which is invalid; indexing will fail.
Apply:
- assert quantize_axis <= len(w.shape), (
+ assert 0 <= quantize_axis < len(w.shape), (
f"incorrect quantize-axis {quantize_axis}, w-shape={w.shape}"
)
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/quant_utils.py around lines 200 to 215, the
assertion `quantize_axis <= len(w.shape)` incorrectly allows equality and can
cause an out-of-bounds index; replace it with a proper bounds check such as
`assert 0 <= quantize_axis < len(w.shape)` (or if negative axes should be
allowed, use `assert -len(w.shape) <= quantize_axis < len(w.shape)`) and keep
the same error message/context so indexing into w.shape cannot fail.
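A quick check of the padding contract discussed above (_pad is a private helper, so this is illustrative only and assumes modelopt at this PR's revision):

```python
import numpy as np
from modelopt.onnx.quantization.quant_utils import _pad

w = np.ones((100, 8), dtype=np.float32)
padded = _pad(w, block_size=32, quantize_axis=0)
assert padded.shape == (128, 8)   # 100 rounded up to the next multiple of 32
assert np.all(padded[100:] == 0)  # padding is zero-filled
```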
def find_scales(
    w: np.ndarray,
    block_size: int,
    quantize_axis: int = 0,
    alpha: float = 1.0,
    use_zero_point: bool = False,
    num_bits: int = 4,
):
    """Find scale factors for `w` via `s = max(w.block(block_size)) / 7`."""
    w = _pad(w, block_size, quantize_axis)
    if quantize_axis == 0:
        w = w.T

    s_last_dim = w.shape[-1] // block_size
    s_shape = list(w.shape)
    s_shape[-1] = s_last_dim
    z = None
    if not use_zero_point:
        scale = 2 ** (num_bits - 1) - 1
        w_amax = np.abs(w.reshape(-1, block_size)).max(axis=-1)
        s = (w_amax * alpha) / scale
        s = s.reshape(s_shape)
    else:
Prevent division by zero: clamp scales in both branches.
When all-zeros blocks occur, s
can be 0 in the no-zero-point path; later division in rtn
will hit inf/NaN.
Apply:
- s = (w_amax * alpha) / scale
+ s = (w_amax * alpha) / scale
+ s = np.clip(s, CLIP_MIN, None)
@@
- s = s.reshape(s_shape)
+ s = np.clip(s, CLIP_MIN, None).reshape(s_shape)
Also applies to: 269-289
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/quant_utils.py around lines 242-264 (and also
apply same fix to lines 269-289), the computed scale array s can be zero for
all-zero blocks which will cause division-by-zero later; clamp s to a small
positive epsilon after computing it in both the no-zero-point and zero-point
branches (e.g. s = np.maximum(s, 1e-8) or set zero entries to eps) so downstream
divisions never produce inf/NaN, preserving the original shapes and dtype.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #310 +/- ##
==========================================
- Coverage 73.88% 73.63% -0.25%
==========================================
Files 172 172
Lines 17444 17587 +143
==========================================
+ Hits 12889 12951 +62
- Misses 4555 4636 +81
☔ View full report in Codecov by Sentry.
"--enable_mixed_quant", | ||
default=False, | ||
action="store_true", | ||
help="True when we want to use mixed quantization", | ||
) | ||
parser.add_argument( | ||
"--int8_layers", |
Can we just assume mixed_quant is enabled if --int8_layers is non-empty, and remove the --enable_mixed_quant option?
If --int8_layers is specified, we select only the layers matching the patterns in int8_layers to be quantized to INT8; otherwise, if only --enable_mixed_quant is specified, we select a hardcoded set of important layers, similar to what other quantization tools like model builder/llama.cpp do.
Example:
if "python quantize.py ... --int8_layers="layer.0" --enable_mixed_quant" => all layer.0 layers will be quantized to INT8
else "python quantize.py ... --enable_mixed_quant" => quantize to INT8 the first 1/8, last 1/8, and every 3rd layer of the attn and ffn layers below:
/model/layers.{i}/attn/qkv_proj/MatMul
/model/layers.{i}/attn/v_proj/MatMul
/model/layers.{i}/mlp/down_proj/MatMul
Gotcha, makes sense now.
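A rough sketch of the default selection described in that exchange (first 1/8, last 1/8, and every 3rd layer of the listed attn/ffn projections); the exact rule implemented in the PR may differ in details such as rounding.

```python
num_layers = 32  # hypothetical decoder depth
eighth = max(1, num_layers // 8)
int8_layer_ids = sorted(
    set(range(eighth))                              # first 1/8
    | set(range(num_layers - eighth, num_layers))   # last 1/8
    | set(range(0, num_layers, 3))                  # every 3rd layer
)
int8_patterns = [
    f"/model/layers.{i}/{proj}/MatMul"
    for i in int8_layer_ids
    for proj in ("attn/qkv_proj", "attn/v_proj", "mlp/down_proj")
]
print(int8_layer_ids)
```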
default="", | ||
help=( | ||
"Comma-separated list of layer patterns to quantize to INT8 instead of INT4." | ||
"Example: 'layers.0,layers.1,lm_head'" |
nit: "Overrides default mixed quant strategy. Example: 'layers.0,lm_head'"
"--enable_mixed_quant", | ||
default=False, | ||
action="store_true", | ||
help="True when we want to use mixed quantization", |
nit: help="Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd layer quantized to INT8; others to INT4.",
Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models
What does this PR do?
Type of change: new feature
Overview: Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models
Usage
Testing
Tested sanity/functional testing, MMLU, and perf for the models below.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Improvements
Documentation