
Commit b069cec

[5506930] Add support in ModelOpt for generating mixed-precision (INT4+INT8) ONNX models; handle comments and rename functions and variables
Signed-off-by: unknown <[email protected]>
1 parent b7b134e commit b069cec

File tree

6 files changed, +94 -62 lines changed


examples/windows/onnx_ptq/genai_llm/README.md

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ The table below lists key command-line arguments of the ONNX PTQ example script.
 | `--calibration_eps` | dml, cuda, cpu, NvTensorRtRtx (default: [dml,cpu]) | List of execution-providers to use for session run during calibration |
 | `--no_position_ids` | Default: position_ids input enabled | Use this option to disable position_ids input in calibration data |
 | `--enable_mixed_quant` | Default: disabled mixed quant | Use this option to enable mixed precision quantization |
+| `--layers_8bit` | Default: None | Use this option to override the default mixed quant strategy |
 
 Run the following command to view all available parameters in the script:
 
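For orientation, a possible way to combine the new flags with an existing run of `quantize.py`; the placeholder `<usual arguments>` stands for the model and calibration options documented earlier in this README, and the exact required set is not shown in this diff:

```bash
# Use the built-in mixed-precision strategy (INT8 for the first/last 1/8 and every 3rd attn/mlp layer):
python quantize.py <usual arguments> --enable_mixed_quant

# Or pin specific layers to 8 bits, which overrides the default strategy:
python quantize.py <usual arguments> --layers_8bit "layers.0,lm_head"
```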

examples/windows/onnx_ptq/genai_llm/quantize.py

Lines changed: 12 additions & 8 deletions
@@ -361,11 +361,15 @@ def main(args):
     # device = torch.device(f"cuda:{device_id}")
     device = torch.device(args.device)
 
+    if args.layers_8bit:
+        args.enable_mixed_quant = True
+
     print(
         f"\n--Quantize-Script-- algo={args.algo}, dataset={args.dataset}, calib_size={args.calib_size}, "
         f"batch_size={args.batch_size}, block_size={args.block_size}, add-position-ids={args.add_position_ids}, "
         f"past-kv={args.add_past_kv_inputs}, rcalib={args.use_random_calib}, device={args.device}, "
-        f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32} enable_mixed_quant={args.enable_mixed_quant}\n"
+        f"use_zero_point={args.use_zero_point}, use_fp32={args.use_fp32}, enable_mixed_quant={args.enable_mixed_quant}, "
+        f"layers_8bit={args.layers_8bit}\n"
     )
 
     print(
@@ -436,7 +440,7 @@ def main(args):
         awqclip_alpha_min=args.awqclip_alpha_min,
         awqclip_bsz_col=args.awqclip_bsz_col,
         enable_mixed_quant=args.enable_mixed_quant,
-        int8_layers=args.int8_layers,
+        layers_8bit=args.layers_8bit,
     )
     logging.info(f"\nQuantization process took {time.time() - t} seconds")
 
@@ -600,16 +604,16 @@ def main(args):
         "--enable_mixed_quant",
         default=False,
         action="store_true",
-        help="True when we want to use mixed quantization",
+        help=(
+            "Use the default mixed quantization strategy: first 1/8, last 1/8, and every 3rd "
+            "attn/mlp layer quantized to 8 bits; others to 4 bits."
+        ),
     )
     parser.add_argument(
-        "--int8_layers",
+        "--layers_8bit",
         type=str,
         default="",
-        help=(
-            "Comma-separated list of layer patterns to quantize to INT8 instead of INT4."
-            "Example: 'layers.0,layers.1,lm_head'"
-        ),
+        help=("Overrides the default mixed quant strategy. Example: 'layers.0,lm_head'"),
    )
    args = parser.parse_args()
    main(args)
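For reference, a minimal standalone sketch of the flag semantics introduced above (plain argparse; everything else in quantize.py is omitted). Passing `--layers_8bit` implicitly turns on mixed quantization, mirroring the `if args.layers_8bit:` block in the first hunk:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_mixed_quant",
    default=False,
    action="store_true",
    help="Use the default mixed quantization strategy.",
)
parser.add_argument(
    "--layers_8bit",
    type=str,
    default="",
    help="Overrides the default mixed quant strategy. Example: 'layers.0,lm_head'",
)

# Simulate: the user only passes --layers_8bit.
args = parser.parse_args(["--layers_8bit", "layers.0,lm_head"])

# Mirroring the diff: an explicit 8-bit layer list implies mixed quantization.
if args.layers_8bit:
    args.enable_mixed_quant = True

print(args.enable_mixed_quant, args.layers_8bit)  # True layers.0,lm_head
```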

modelopt/onnx/quantization/graph_utils.py

Lines changed: 31 additions & 33 deletions
@@ -626,11 +626,11 @@ def _find_nodes_from_op_types_to_exclude(graph: Graph, op_types_to_exclude=None)
     return nodes_to_exclude
 
 
-def _find_quantizable_weights(
+def _find_int4_quantizable_weights(
     graph: onnx.GraphProto,
     nodes_to_exclude: list[str],
 ) -> list[tuple[onnx.ValueInfoProto, onnx.ValueInfoProto, bool, int]]:
-    """Finds the quantizable weights from the graph."""
+    """Finds the int4 quantizable weights from the graph."""
     wa_pack = []
     gemm_nodes = [
         node
@@ -666,27 +666,27 @@ def _find_quantizable_weights(
     return wa_pack
 
 
-def should_quantize_to_int8(layer_name: str, int8_layers: list[str]):
-    """Check if layer should be quantized to INT8.
+def should_quantize_to_8bit(layer_name: str, layers_8bit: list[str]):
+    """Check if a layer should be quantized to 8 bits.
 
-    The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
+    The layers_8bit list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'.
     The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'.
 
     To match these, we:
     - Remove the leading slash from the node name.
     - Replace all '/' with '.' to match the naming convention of the initializer.
 
-    This allows us to correctly identify which weights should be quantized to INT8.
+    This allows us to correctly identify which weights should be quantized to 8 bits.
     """
-    if not int8_layers:
+    if not layers_8bit:
         return False
 
     # Normalize both to dot-delimited tokens and require exact token sequence match.
     def tokens(s: str) -> list[str]:
         return s.lstrip("/").replace("/", ".").split(".")
 
     hay = tokens(layer_name)
-    for pat in int8_layers:
+    for pat in layers_8bit:
         needle = tokens(pat)
         n, m = len(hay), len(needle)
         for i in range(n - m + 1):
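The comparison inside the sliding window is cut off by the hunk above, so the body below fills it in as an assumption; otherwise this is a standalone restatement of the token-based matching the docstring describes:

```python
def tokens(s: str) -> list[str]:
    # Normalize ONNX node names and initializer names to dot-delimited tokens.
    return s.lstrip("/").replace("/", ".").split(".")


def matches_8bit(layer_name: str, layers_8bit: list[str]) -> bool:
    hay = tokens(layer_name)
    for pat in layers_8bit:
        needle = tokens(pat)
        n, m = len(hay), len(needle)
        for i in range(n - m + 1):
            # Assumed window comparison (not visible in the hunk).
            if hay[i : i + m] == needle:
                return True
    return False


print(matches_8bit("model.layers.13.attn.qkv_proj.MatMul.weight",
                   ["/model/layers.13/attn/qkv_proj/MatMul"]))   # True
print(matches_8bit("model.layers.13.attn.qkv_proj.MatMul.weight",
                   ["layers.1"]))                                # False: token match, so 'layers.1' does not hit 'layers.13'
```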
@@ -695,44 +695,42 @@ def tokens(s: str) -> list[str]:
     return False
 
 
-def validate_int8_layers(layers_str: str) -> bool:
-    """Validate the format of int8_layers string."""
+def validate_8bit_layers(layers_str: str) -> bool:
+    """Validate the format of the layers_8bit string."""
     if not layers_str:
         return True
-    # Basic validation: check for valid characters and structure
-    import re
-
-    pattern = r"^[a-zA-Z0-9_.,\-]$"
+    # Allow comma-separated list of path-like tokens
+    pattern = r"^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$"
     return bool(re.match(pattern, layers_str))
 
 
 def get_layer_precision_mapping(
     onnx_model: onnx.ModelProto,
-    int8_precision_pattern: str | None = None,
+    precision_pattern_8bit: str | None = None,
     nodes_to_exclude: list[str] | None = [r"/lm_head"],
 ):
-    """Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model.
+    """Generate a mapping of layer names to their quantization precision (4 or 8 bits) for an ONNX model.
 
     Args:
         onnx_model (onnx.ModelProto): The ONNX model to analyze.
-        int8_precision_pattern (str, optional): Comma-separated string of layer patterns to quantize to INT8.
-            If None, a default set of patterns is used to select layers for INT8 quantization.
+        precision_pattern_8bit (str, optional): Comma-separated string of layer patterns to quantize to 8 bits.
+            If None, a default set of patterns is used to select layers for 8-bit quantization.
         nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization.
             Defaults to [r"/lm_head"].
 
     Returns:
-        dict: A mapping from layer names to their quantization precision (e.g., {"layer_name": "int8"}).
+        dict: A mapping from layer names to their quantization precision (e.g., {"layer_name": 8}).
     """
     graph = onnx_model.graph
 
     nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude)
     # Collect quantizable weight tensors
-    wa_pack = _find_quantizable_weights(graph, nodes_to_exclude)
+    wa_pack = _find_int4_quantizable_weights(graph, nodes_to_exclude)
 
-    if int8_precision_pattern:
-        if not validate_int8_layers(int8_precision_pattern):
-            raise ValueError("Invalid format for --int8_layers. Use comma-separated layers.")
-        int8_layers_list = [x.strip() for x in int8_precision_pattern.split(",") if x.strip()]
+    if precision_pattern_8bit:
+        if not validate_8bit_layers(precision_pattern_8bit):
+            raise ValueError("Invalid format for --layers_8bit. Use comma-separated layers.")
+        layers_list_8bit = [x.strip() for x in precision_pattern_8bit.split(",") if x.strip()]
 
     else:
         matmul_nodes = [
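Worth noting: the removed pattern `r"^[a-zA-Z0-9_.,\-]$"` could only ever match a single character, so any realistic layer list failed validation; the new pattern accepts a whole comma-separated list of path-like tokens. A quick standalone check of the new pattern:

```python
import re

PATTERN = r"^\s*[/a-zA-Z0-9_.\-]+(\s*,\s*[/a-zA-Z0-9_.\-]+)*\s*$"


def validate_8bit_layers(layers_str: str) -> bool:
    if not layers_str:
        return True
    return bool(re.match(PATTERN, layers_str))


print(validate_8bit_layers("layers.0, /model/layers.1/attn/qkv_proj/MatMul"))  # True
print(validate_8bit_layers("layers.0;;lm_head"))                               # False: ';' is not an allowed character
```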
@@ -773,7 +771,7 @@ def extract_group_key(node_name):
             group_key = extract_group_key(node.name)
             group_to_nodes.setdefault(group_key, []).append(node.name)
 
-        int8_layers_set = set()
+        layers_8bit_set = set()
         for names in group_to_nodes.values():
             n = len(names)
             if n == 0:
@@ -788,23 +786,23 @@ def layer_idx(name):
             first_eighth = int(n // 8)
             last_eighth = int(n // 8)
             # First 1/8
-            int8_layers_set.update(names_sorted[:first_eighth])
+            layers_8bit_set.update(names_sorted[:first_eighth])
             # Last 1/8
             if last_eighth > 0:
-                int8_layers_set.update(names_sorted[-last_eighth:])
+                layers_8bit_set.update(names_sorted[-last_eighth:])
             # Every third in the rest (excluding first and last eighth)
             rest_start = first_eighth
             rest_end = n - last_eighth
             for i in range(rest_start, rest_end):
                 if (i - rest_start) % 3 == 0:
-                    int8_layers_set.add(names_sorted[i])
-        int8_layers_list = list(int8_layers_set)
+                    layers_8bit_set.add(names_sorted[i])
+        layers_list_8bit = list(layers_8bit_set)
 
     # NEW: Create precision info mapping
     precision_info = {}
     for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack):
         weight_name = weight_tensor.name
-        if should_quantize_to_int8(weight_name, int8_layers_list):
+        if should_quantize_to_8bit(weight_name, layers_list_8bit):
             precision_info[weight_name] = 8
         else:
             precision_info[weight_name] = 4
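To make the default selection above concrete, a standalone sketch of the same index arithmetic for a hypothetical group of 32 layers (real node names and per-group sorting are omitted):

```python
def default_8bit_indices(n: int) -> list[int]:
    """Indices kept at 8 bits: first 1/8, last 1/8, and every 3rd layer in between."""
    eighth = n // 8
    selected = set(range(eighth))              # first 1/8
    if eighth > 0:
        selected.update(range(n - eighth, n))  # last 1/8
    for i in range(eighth, n - eighth):        # every 3rd in the rest
        if (i - eighth) % 3 == 0:
            selected.add(i)
    return sorted(selected)


print(default_8bit_indices(32))
# [0, 1, 2, 3, 4, 7, 10, 13, 16, 19, 22, 25, 28, 29, 30, 31]
```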
@@ -827,17 +825,17 @@ def get_precision_info(
         nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization.
         **kwargs: Additional keyword arguments, such as:
             - enable_mixed_quant (bool): Whether to enable mixed quantization.
-            - int8_layers (str): Comma-separated list of layer patterns to quantize to INT8.
+            - layers_8bit (str): Comma-separated list of layer patterns to quantize to 8 bits.
 
     Returns:
         dict[str, int] | None: A mapping from weight tensor names to their quantization precision,
             or None if mixed quantization is not enabled.
     """
     precision_info = None
     enable_mixed_quant = kwargs.get("enable_mixed_quant", False)
-    int8_layers = kwargs.get("int8_layers")
+    layers_8bit = kwargs.get("layers_8bit")
     if enable_mixed_quant:
-        precision_info = get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude)
+        precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude)
     else:
         precision_info = None
     return precision_info
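The map returned by `get_precision_info` is consumed in `int4.py` through `get_num_bits(precision_info, name)`. That helper is not part of this diff, so the sketch below is an assumed minimal behavior (fall back to 4 bits when a weight is absent or the map is None), for illustration only:

```python
def get_num_bits(precision_info, name, default=4):
    # Assumed behavior: look the weight up in the precision map, default to 4 bits.
    if precision_info and name in precision_info:
        return precision_info[name]
    return default


precision_info = {
    "model.layers.0.attn.qkv_proj.MatMul.weight": 8,
    "model.layers.1.attn.qkv_proj.MatMul.weight": 4,
}
print(get_num_bits(precision_info, "model.layers.0.attn.qkv_proj.MatMul.weight"))  # 8
print(get_num_bits(None, "model.layers.5.mlp.down_proj.MatMul.weight"))            # 4 (mixed quant disabled)
```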

modelopt/onnx/quantization/int4.py

Lines changed: 19 additions & 9 deletions
@@ -36,7 +36,9 @@
 from modelopt.onnx.op_types import is_fusible_scaling_op
 from modelopt.onnx.quantization.calib_utils import RandomDataProvider
 from modelopt.onnx.quantization.graph_utils import (
-    _find_quantizable_weights,
+    _find_int4_quantizable_weights as _find_quantizable_weights,
+)
+from modelopt.onnx.quantization.graph_utils import (
     expand_node_names_from_patterns,
     get_precision_info,
     get_tensor_consumer_nodes,
@@ -50,9 +52,9 @@
     find_scales,
     get_num_bits,
     quant_tensor,
+    reshape_scales_for_per_channel_nodes,
     rtn,
     update_block_size,
-    update_scale_map_for_per_channel_nodes,
 )
 from modelopt.onnx.utils import save_onnx
 
@@ -121,6 +123,7 @@ def _quantize_gather_nodes(
             continue
         name = in_tensor.name
         w = in_tensor.values
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
        num_bits = get_num_bits(precision_info, name)
        block_size_updated = update_block_size(
            num_bits, block_size, w=w, quantize_axis=gather_quantize_axis
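`update_block_size` itself is not shown in this commit. Based on the comment repeated throughout this file, the hedged sketch below captures the intended effect (per-channel scaling for 8-bit weights, the configured block size otherwise); the real implementation may differ:

```python
import numpy as np


def update_block_size_sketch(num_bits: int, block_size: int, w: np.ndarray, quantize_axis: int = 0) -> int:
    # Assumption: 8-bit weights use per-channel quantization, i.e. one block that
    # spans the whole quantization axis; 4-bit weights keep the configured block size.
    if num_bits == 8:
        return w.shape[quantize_axis]
    return block_size


w = np.zeros((4096, 11008), dtype=np.float32)
print(update_block_size_sketch(4, 128, w))  # 128  -> block-wise INT4
print(update_block_size_sketch(8, 128, w))  # 4096 -> per-channel INT8
```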
@@ -170,7 +173,7 @@ def _quantize_gather_nodes(
         )
     else:
         logger.info("Found 0 Gather nodes to quantize")
-    scales_map = update_scale_map_for_per_channel_nodes(scales_map, block_size, precision_info)
+    scales_map = reshape_scales_for_per_channel_nodes(scales_map, block_size, precision_info)
     return weights_map, scales_map, zero_point_map
 
 
@@ -221,6 +224,7 @@ def quantize_rtn(
     precision_info = get_precision_info(onnx_model, nodes_to_exclude, **kwargs)
     for name, w in gemm_weights.items():
         logger.debug(f"Computing scales for weight {name} of shape {w.shape}")
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, name)
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         s, zp = find_scales(np.asarray(w), block_size_updated, num_bits=num_bits)
@@ -258,14 +262,15 @@ def quantize_rtn(
     gemm_weights_quantized = {}
     for name, w in gemm_weights.items():
         logger.debug(f"Quantizing weight {name}")
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, name)
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         qw = rtn(np.asarray(w), scales[name], block_size_updated, num_bits=num_bits)
         if has_cupy:
             qw = np.asnumpy(qw)
             scales[name] = np.asnumpy(scales[name])
         gemm_weights_quantized[name] = numpy.asarray(qw)
-    scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+    scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
     qdq.insert_dq_nodes(
         graph,
         scales,
@@ -285,7 +290,7 @@ def quantize_rtn(
         if has_cupy:
             for name in scales:
                 scales[name] = np.asnumpy(scales[name])
-        scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+        scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
         qdq.insert_qdq_nodes(graph, scales, weight_map=gemm_tensors, precision_info=precision_info)
         if gather_w_map is not None:
             assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
@@ -497,6 +502,7 @@ def _quantize_awq_clip(
             w = w.T
         w = np.asarray(w)
         num_bits = get_num_bits(precision_info, weight_tensor.name)
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         awq_clip = AWQClipHelper(w, block_size_updated, **kwargs)
         _clip_search(x, w, awq_clip, num_bits=num_bits, **kwargs)
@@ -524,7 +530,9 @@ def _quantize_awq_clip(
 
         alpha = alphas.get(weight_tensor.name, 1)
         num_bits = get_num_bits(precision_info, weight_tensor.name)
-        qw, scale, _ = quant_tensor(w, block_size, alpha=alpha, num_bits=num_bits)
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
+        block_size_updated = update_block_size(num_bits, block_size, w=w)
+        qw, scale, _ = quant_tensor(w, block_size_updated, alpha=alpha, num_bits=num_bits)
         if has_cupy:
             qw = np.asnumpy(qw)
             scale = np.asnumpy(scale)
@@ -561,7 +569,7 @@ def _quantize_awq_clip(
 
     t = time.time()
     dq_node_attributes = {"axis": 0, "block_size": block_size}
-    scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+    scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
     qdq.insert_dq_nodes(
         graph_gs,
         scales,
@@ -716,6 +724,7 @@ def run_awq_scale_search_per_node(
         x = np.concatenate(output_dicts[act_tensor.name], axis=0).reshape(
             (-1, w.shape[0])
         )  # n_token, ci
+        # Update the block size: for 8-bit quantization, per-channel quantization is used.
         num_bits = get_num_bits(precision_info, weight_tensor.name)
         block_size_updated = update_block_size(num_bits, block_size, w=w)
         awq_lite[i] = AWQLiteHelper(x, w, block_size_updated, **kwargs)
@@ -1129,6 +1138,7 @@ def _quantize_awq_lite(
             assert enable_weight_clipping or (alpha == 1), (
                 "clip range enabled without enabling weight-clipping param"
             )
+            # Update the block size: for 8-bit quantization, per-channel quantization is used.
             num_bits = get_num_bits(precision_info, weight_tensor.name)
             block_size_updated = update_block_size(num_bits, block_size, w=w_scaled)
             qw, scale, zp = quant_tensor(
@@ -1262,7 +1272,7 @@ def _quantize_awq_lite(
 
     t = time.time()
     dq_node_attributes = {"axis": 0, "block_size": block_size}
-    scales = update_scale_map_for_per_channel_nodes(scales, block_size, precision_info)
+    scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
     qdq.insert_dq_nodes(
         graph_gs,
         scales,
@@ -1371,7 +1381,7 @@ def quantize(
             Default: 32.
         - **enable_mixed_quant** (bool): If True, enable mixed quantization.
             Default: False.
-        - **int8_layers** (str): comma-separated list of layer patterns to quantize to INT8 instead of INT4.
+        - **layers_8bit** (str): Comma-separated list of layer patterns to quantize to INT8 instead of INT4.
             Default: [].
     **Returns**: A quantized ONNX model in ONNX ModelProto format.
     """
