-
Notifications
You must be signed in to change notification settings - Fork 162
[5506930]Add support in ModelOpt for generating mixed-precision (INT4… #310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -365,7 +365,7 @@ def main(args): | |
f"\n--Quantize-Script-- algo={args.algo}, dataset={args.dataset}, calib_size={args.calib_size}, " | ||
f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, " | ||
f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, " | ||
f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32}\n" | ||
f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} enable_mixed_quant={args.enable_mixed_quant}\n" | ||
) | ||
|
||
print( | ||
|
@@ -435,6 +435,8 @@ def main(args): | |
awqclip_alpha_step=args.awqclip_alpha_step, | ||
awqclip_alpha_min=args.awqclip_alpha_min, | ||
awqclip_bsz_col=args.awqclip_bsz_col, | ||
enable_mixed_quant=args.enable_mixed_quant, | ||
int8_layers=args.int8_layers, | ||
) | ||
logging.info(f"\nQuantization process took {time.time() - t} seconds") | ||
|
||
|
@@ -594,6 +596,20 @@ def main(args): | |
default=False, | ||
action="store_true", | ||
) | ||
|
||
parser.add_argument( | ||
"--enable_mixed_quant", | ||
default=False, | ||
action="store_true", | ||
help="True when we want to use mixed quantization", | ||
) | ||
parser.add_argument( | ||
"--int8_layers", | ||
Comment on lines
+600
to
+606
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use just assume mixed_quant enabled if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if --int8_layers is specified we will select only those layers to be quantized in int8 which match the patter in int8_layers, else if only --enable_mixed_quant is specified we hardcode select few important layers similar to what some other quantization tools like model builder/llama.cpp are doing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gotcha, make sense now. |
||
type=str, | ||
default="", | ||
help=( | ||
"Comma-separated list of layer patterns to quantize to INT8 instead of INT4." | ||
"Example: 'layers.0,layers.1,lm_head'" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
), | ||
) | ||
args = parser.parse_args() | ||
main(args) |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -18,6 +18,7 @@ | |||||||||||||||||||||||||||||||||||||||
import re | ||||||||||||||||||||||||||||||||||||||||
from collections import defaultdict | ||||||||||||||||||||||||||||||||||||||||
from functools import reduce | ||||||||||||||||||||||||||||||||||||||||
from typing import Any, cast | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||||||||
import onnx | ||||||||||||||||||||||||||||||||||||||||
|
@@ -625,6 +626,223 @@ def _find_nodes_from_op_types_to_exclude(graph: Graph, op_types_to_exclude=None) | |||||||||||||||||||||||||||||||||||||||
return nodes_to_exclude | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def _find_quantizable_weights( | ||||||||||||||||||||||||||||||||||||||||
graph: onnx.GraphProto, | ||||||||||||||||||||||||||||||||||||||||
nodes_to_exclude: list[str], | ||||||||||||||||||||||||||||||||||||||||
) -> list[tuple[onnx.ValueInfoProto, onnx.ValueInfoProto, bool, int]]: | ||||||||||||||||||||||||||||||||||||||||
"""Finds the quantizable weights from the graph.""" | ||||||||||||||||||||||||||||||||||||||||
wa_pack = [] | ||||||||||||||||||||||||||||||||||||||||
gemm_nodes = [ | ||||||||||||||||||||||||||||||||||||||||
node | ||||||||||||||||||||||||||||||||||||||||
for node in graph.node | ||||||||||||||||||||||||||||||||||||||||
if node.op_type in ["Gemm", "MatMul"] and node.name not in nodes_to_exclude | ||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||
initializer_idxs = {initializer.name: idx for idx, initializer in enumerate(graph.initializer)} | ||||||||||||||||||||||||||||||||||||||||
for gemm in gemm_nodes: | ||||||||||||||||||||||||||||||||||||||||
if gemm.input[0] in initializer_idxs: | ||||||||||||||||||||||||||||||||||||||||
# Ex. two const input to MatMul_115 in fastvit0.onnx | ||||||||||||||||||||||||||||||||||||||||
# Note. RTN algorithm will quantize these weights though | ||||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
if gemm.input[1] not in initializer_idxs: | ||||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
weight_tensor = graph.initializer[initializer_idxs[gemm.input[1]]] | ||||||||||||||||||||||||||||||||||||||||
if len(weight_tensor.dims) == 1: # 1D blocked quantization not supported | ||||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
gemm_io_type = cast("int", weight_tensor.data_type) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
act_tensor = onnx.helper.ValueInfoProto() | ||||||||||||||||||||||||||||||||||||||||
act_tensor.name = gemm.input[0] | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# TODO: support transA by transposing activation tensors in _clip_search | ||||||||||||||||||||||||||||||||||||||||
do_transpose = gemm.op_type == "Gemm" and any( | ||||||||||||||||||||||||||||||||||||||||
attr.name == "transB" and attr.i > 0 for attr in gemm.attribute | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type)) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
return wa_pack | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def should_quantize_to_int8(layer_name: str, int8_layers: list[str]): | ||||||||||||||||||||||||||||||||||||||||
"""Check if layer should be quantized to INT8. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'. | ||||||||||||||||||||||||||||||||||||||||
The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
To match these, we: | ||||||||||||||||||||||||||||||||||||||||
- Remove the leading slash from the node name. | ||||||||||||||||||||||||||||||||||||||||
- Replace all '/' with '.' to match the naming convention of the initializer. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
This allows us to correctly identify which weights should be quantized to INT8. | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
if not int8_layers: | ||||||||||||||||||||||||||||||||||||||||
return False | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# Normalize both to dot-delimited tokens and require exact token sequence match. | ||||||||||||||||||||||||||||||||||||||||
def tokens(s: str) -> list[str]: | ||||||||||||||||||||||||||||||||||||||||
return s.lstrip("/").replace("/", ".").split(".") | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
hay = tokens(layer_name) | ||||||||||||||||||||||||||||||||||||||||
for pat in int8_layers: | ||||||||||||||||||||||||||||||||||||||||
needle = tokens(pat) | ||||||||||||||||||||||||||||||||||||||||
n, m = len(hay), len(needle) | ||||||||||||||||||||||||||||||||||||||||
for i in range(n - m + 1): | ||||||||||||||||||||||||||||||||||||||||
if hay[i : i + m] == needle: | ||||||||||||||||||||||||||||||||||||||||
return True | ||||||||||||||||||||||||||||||||||||||||
return False | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def validate_int8_layers(layers_str: str) -> bool: | ||||||||||||||||||||||||||||||||||||||||
"""Validate the format of int8_layers string.""" | ||||||||||||||||||||||||||||||||||||||||
if not layers_str: | ||||||||||||||||||||||||||||||||||||||||
return True | ||||||||||||||||||||||||||||||||||||||||
# Basic validation: check for valid characters and structure | ||||||||||||||||||||||||||||||||||||||||
import re | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
pattern = r"^[a-zA-Z0-9_.,\-]$" | ||||||||||||||||||||||||||||||||||||||||
return bool(re.match(pattern, layers_str)) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Comment on lines
+698
to
+707
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix regex validation to accept real layer patterns. The current pattern Apply: -def validate_int8_layers(layers_str: str) -> bool:
+def validate_int8_layers(layers_str: str) -> bool:
@@
- pattern = r"^[a-zA-Z0-9_.,\-]$"
- return bool(re.match(pattern, layers_str))
+ # Allow comma-separated list of path-like tokens
+ pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
+ return bool(re.match(pattern, layers_str)) 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Look into this suggestion There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain@i-riyad I've analyzed the suggestion and it's correct - there's a critical bug in the regex validation. The CodeRabbit suggestion is absolutely correct. The current regex has two major flaws:
The suggested fix properly addresses both issues by:
However, I'd recommend a small improvement to make the validation more robust: def validate_int8_layers(layers_str: str) -> bool:
"""Validate the format of int8_layers string."""
if not layers_str:
return True
# Allow comma-separated list of path-like tokens
# Pattern allows: letters, digits, underscore, dot, slash, hyphen
# Supports comma-separated values with optional whitespace
pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
return bool(re.fullmatch(pattern, layers_str)) Using The suggestion should definitely be implemented to fix this validation bug. 🏁 Script executed:
Length of output: 634 Perfect! The test results confirm the analysis. The current implementation is severely broken - it only accepts single characters or empty strings, making it useless for validating real layer patterns. Key findings:
Recommendation: Implement the CodeRabbit suggestion immediately. This is a critical bug that makes the validation function completely non-functional for its intended purpose. Here's the corrected implementation: def validate_int8_layers(layers_str: str) -> bool:
"""Validate the format of int8_layers string."""
if not layers_str:
return True
# Basic validation: check for valid characters and structure
import re
# Allow comma-separated list of path-like tokens
pattern = r'^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$'
return bool(re.match(pattern, layers_str)) The current regex would cause the mixed-precision quantization feature to fail silently or reject all user-provided layer specifications, which would be a significant usability issue. |
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def get_layer_precision_mapping( | ||||||||||||||||||||||||||||||||||||||||
onnx_model: onnx.ModelProto, | ||||||||||||||||||||||||||||||||||||||||
int8_precision_pattern: str | None = None, | ||||||||||||||||||||||||||||||||||||||||
nodes_to_exclude: list[str] | None = [r"/lm_head"], | ||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||
"""Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model. | ||||||||||||||||||||||||||||||||||||||||
Comment on lines
+709
to
+714
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Avoid mutable default args and normalize exclusions. Using a list default ( Apply: -def get_layer_precision_mapping(
+def get_layer_precision_mapping(
onnx_model: onnx.ModelProto,
int8_precision_pattern: str | None = None,
- nodes_to_exclude: list[str] | None = [r"/lm_head"],
+ nodes_to_exclude: list[str] | None = None,
):
@@
- nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
+ nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+ nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude) Also applies to: 728-731 🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||
onnx_model (onnx.ModelProto): The ONNX model to analyze. | ||||||||||||||||||||||||||||||||||||||||
int8_precision_pattern (str, optional): Comma-separated string of layer patterns to quantize to INT8. | ||||||||||||||||||||||||||||||||||||||||
If None, a default set of patterns is used to select layers for INT8 quantization. | ||||||||||||||||||||||||||||||||||||||||
nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization. | ||||||||||||||||||||||||||||||||||||||||
Defaults to [r"/lm_head"]. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||||
dict: A mapping from layer names to their quantization precision (e.g., {"layer_name": "int8"}). | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
graph = onnx_model.graph | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude) | ||||||||||||||||||||||||||||||||||||||||
# Collect quantizable weight tensors | ||||||||||||||||||||||||||||||||||||||||
wa_pack = _find_quantizable_weights(graph, nodes_to_exclude) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
if int8_precision_pattern: | ||||||||||||||||||||||||||||||||||||||||
if not validate_int8_layers(int8_precision_pattern): | ||||||||||||||||||||||||||||||||||||||||
raise ValueError("Invalid format for --int8_layers. Use comma-separated layers.") | ||||||||||||||||||||||||||||||||||||||||
int8_layers_list = [x.strip() for x in int8_precision_pattern.split(",") if x.strip()] | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||
matmul_nodes = [ | ||||||||||||||||||||||||||||||||||||||||
node | ||||||||||||||||||||||||||||||||||||||||
for node in onnx_model.graph.node | ||||||||||||||||||||||||||||||||||||||||
if node.op_type == "MatMul" and "lm_head" not in node.name | ||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# Only include nodes matching the specified patterns for all layers present in the model | ||||||||||||||||||||||||||||||||||||||||
# For example, for all i where a node exists with name: | ||||||||||||||||||||||||||||||||||||||||
# /model/layers.{i}/attn/qkv_proj/MatMul | ||||||||||||||||||||||||||||||||||||||||
# /model/layers.{i}/attn/v_proj/MatMul | ||||||||||||||||||||||||||||||||||||||||
# /model/layers.{i}/mlp/down_proj/MatMul | ||||||||||||||||||||||||||||||||||||||||
pattern_regexes = [ | ||||||||||||||||||||||||||||||||||||||||
re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"), | ||||||||||||||||||||||||||||||||||||||||
re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"), | ||||||||||||||||||||||||||||||||||||||||
re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"), | ||||||||||||||||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# Filter matmul_nodes to only those matching the patterns | ||||||||||||||||||||||||||||||||||||||||
filtered_matmul_nodes = [] | ||||||||||||||||||||||||||||||||||||||||
for node in matmul_nodes: | ||||||||||||||||||||||||||||||||||||||||
for pat in pattern_regexes: | ||||||||||||||||||||||||||||||||||||||||
if pat.match(node.name): | ||||||||||||||||||||||||||||||||||||||||
filtered_matmul_nodes.append(node) | ||||||||||||||||||||||||||||||||||||||||
break | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# Build a mapping from group key to list of node names (ordered by layer index if possible) | ||||||||||||||||||||||||||||||||||||||||
def extract_group_key(node_name): | ||||||||||||||||||||||||||||||||||||||||
# Extract the two components before 'MatMul' in the name, e.g. ...foo.bar.MatMul | ||||||||||||||||||||||||||||||||||||||||
parts = node_name.split("/") | ||||||||||||||||||||||||||||||||||||||||
if len(parts) >= 3: | ||||||||||||||||||||||||||||||||||||||||
return ".".join(parts[-3:-1]) | ||||||||||||||||||||||||||||||||||||||||
return node_name | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
group_to_nodes = {} | ||||||||||||||||||||||||||||||||||||||||
for node in filtered_matmul_nodes: | ||||||||||||||||||||||||||||||||||||||||
group_key = extract_group_key(node.name) | ||||||||||||||||||||||||||||||||||||||||
group_to_nodes.setdefault(group_key, []).append(node.name) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
int8_layers_set = set() | ||||||||||||||||||||||||||||||||||||||||
for names in group_to_nodes.values(): | ||||||||||||||||||||||||||||||||||||||||
n = len(names) | ||||||||||||||||||||||||||||||||||||||||
if n == 0: | ||||||||||||||||||||||||||||||||||||||||
continue | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# Try to sort by layer index if present | ||||||||||||||||||||||||||||||||||||||||
def layer_idx(name): | ||||||||||||||||||||||||||||||||||||||||
m = re.search(r"layers\.(\d+)\.", name) | ||||||||||||||||||||||||||||||||||||||||
return int(m.group(1)) if m else 0 | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
names_sorted = sorted(names, key=layer_idx) | ||||||||||||||||||||||||||||||||||||||||
first_eighth = int(n // 8) | ||||||||||||||||||||||||||||||||||||||||
last_eighth = int(n // 8) | ||||||||||||||||||||||||||||||||||||||||
# First 1/8 | ||||||||||||||||||||||||||||||||||||||||
int8_layers_set.update(names_sorted[:first_eighth]) | ||||||||||||||||||||||||||||||||||||||||
# Last 1/8 | ||||||||||||||||||||||||||||||||||||||||
if last_eighth > 0: | ||||||||||||||||||||||||||||||||||||||||
int8_layers_set.update(names_sorted[-last_eighth:]) | ||||||||||||||||||||||||||||||||||||||||
# Every third in the rest (excluding first and last eighth) | ||||||||||||||||||||||||||||||||||||||||
rest_start = first_eighth | ||||||||||||||||||||||||||||||||||||||||
rest_end = n - last_eighth | ||||||||||||||||||||||||||||||||||||||||
for i in range(rest_start, rest_end): | ||||||||||||||||||||||||||||||||||||||||
if (i - rest_start) % 3 == 0: | ||||||||||||||||||||||||||||||||||||||||
int8_layers_set.add(names_sorted[i]) | ||||||||||||||||||||||||||||||||||||||||
int8_layers_list = list(int8_layers_set) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# NEW: Create precision info mapping | ||||||||||||||||||||||||||||||||||||||||
precision_info = {} | ||||||||||||||||||||||||||||||||||||||||
for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack): | ||||||||||||||||||||||||||||||||||||||||
weight_name = weight_tensor.name | ||||||||||||||||||||||||||||||||||||||||
if should_quantize_to_int8(weight_name, int8_layers_list): | ||||||||||||||||||||||||||||||||||||||||
precision_info[weight_name] = 8 | ||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||
precision_info[weight_name] = 4 | ||||||||||||||||||||||||||||||||||||||||
return precision_info | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def get_precision_info( | ||||||||||||||||||||||||||||||||||||||||
onnx_model: onnx.ModelProto, | ||||||||||||||||||||||||||||||||||||||||
nodes_to_exclude: list[str] | None = [r"/lm_head"], | ||||||||||||||||||||||||||||||||||||||||
**kwargs: Any, | ||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||
"""Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits). | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
This function determines the quantization precision for each weight tensor in the ONNX model, | ||||||||||||||||||||||||||||||||||||||||
based on the provided configuration. If mixed quantization is enabled, it uses the layer | ||||||||||||||||||||||||||||||||||||||||
precision mapping; otherwise, it returns None. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||
onnx_model (onnx.ModelProto): The ONNX model to analyze. | ||||||||||||||||||||||||||||||||||||||||
nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization. | ||||||||||||||||||||||||||||||||||||||||
**kwargs: Additional keyword arguments, such as: | ||||||||||||||||||||||||||||||||||||||||
- enable_mixed_quant (bool): Whether to enable mixed quantization. | ||||||||||||||||||||||||||||||||||||||||
- int8_layers (str): Comma-separated list of layer patterns to quantize to INT8. | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||||
dict[str, int] | None: A mapping from weight tensor names to their quantization precision, | ||||||||||||||||||||||||||||||||||||||||
or None if mixed quantization is not enabled. | ||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||
precision_info = None | ||||||||||||||||||||||||||||||||||||||||
enable_mixed_quant = kwargs.get("enable_mixed_quant", False) | ||||||||||||||||||||||||||||||||||||||||
int8_layers = kwargs.get("int8_layers") | ||||||||||||||||||||||||||||||||||||||||
if enable_mixed_quant: | ||||||||||||||||||||||||||||||||||||||||
precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude) | ||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||
precision_info = None | ||||||||||||||||||||||||||||||||||||||||
return precision_info | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
Comment on lines
+814
to
+844
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Same mutable default issue in public API. Apply same fix as above. -def get_precision_info(
+def get_precision_info(
onnx_model: onnx.ModelProto,
- nodes_to_exclude: list[str] | None = [r"/lm_head"],
+ nodes_to_exclude: list[str] | None = None,
**kwargs: Any,
):
@@
- enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
+ nodes_to_exclude = nodes_to_exclude or [r"/lm_head"]
+ enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def expand_node_names_from_patterns( | ||||||||||||||||||||||||||||||||||||||||
graph: onnx.GraphProto | Graph, name_patterns: list[str] | None = None | ||||||||||||||||||||||||||||||||||||||||
) -> list[str]: | ||||||||||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
help="Use default mixed quantization strategy: first 1/8, last 1/8, and every 3rd layer quantized to INT8; others to INT4.",