|
18 | 18 | import re
|
19 | 19 | from collections import defaultdict
|
20 | 20 | from functools import reduce
|
| 21 | +from typing import Any, cast |
21 | 22 |
|
22 | 23 | import numpy as np
|
23 | 24 | import onnx
|
@@ -625,6 +626,223 @@ def _find_nodes_from_op_types_to_exclude(graph: Graph, op_types_to_exclude=None)
|
625 | 626 | return nodes_to_exclude
|
626 | 627 |
|
627 | 628 |
|
| 629 | +def _find_quantizable_weights( |
| 630 | + graph: onnx.GraphProto, |
| 631 | + nodes_to_exclude: list[str], |
| 632 | +) -> list[tuple[onnx.ValueInfoProto, onnx.TensorProto, bool, int]]: |
| 633 | + """Finds the quantizable weights from the graph.""" |
| 634 | + wa_pack = [] |
| 635 | + gemm_nodes = [ |
| 636 | + node |
| 637 | + for node in graph.node |
| 638 | + if node.op_type in ["Gemm", "MatMul"] and node.name not in nodes_to_exclude |
| 639 | + ] |
| 640 | + initializer_idxs = {initializer.name: idx for idx, initializer in enumerate(graph.initializer)} |
| 641 | + for gemm in gemm_nodes: |
| 642 | + if gemm.input[0] in initializer_idxs: |
| 643 | +            # e.g. both inputs to MatMul_115 in fastvit0.onnx are initializers |
| 644 | +            # Note: the RTN algorithm will still quantize these weights |
| 645 | + continue |
| 646 | + |
| 647 | + if gemm.input[1] not in initializer_idxs: |
| 648 | + continue |
| 649 | + |
| 650 | + weight_tensor = graph.initializer[initializer_idxs[gemm.input[1]]] |
| 651 | + if len(weight_tensor.dims) == 1: # 1D blocked quantization not supported |
| 652 | + continue |
| 653 | + |
| 654 | + gemm_io_type = cast("int", weight_tensor.data_type) |
| 655 | + |
| 656 | + act_tensor = onnx.helper.ValueInfoProto() |
| 657 | + act_tensor.name = gemm.input[0] |
| 658 | + |
| 659 | + # TODO: support transA by transposing activation tensors in _clip_search |
| 660 | + do_transpose = gemm.op_type == "Gemm" and any( |
| 661 | + attr.name == "transB" and attr.i > 0 for attr in gemm.attribute |
| 662 | + ) |
| 663 | + |
| 664 | + wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type)) |
| 665 | + |
| 666 | + return wa_pack |
| 667 | + |
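For context, a minimal sketch of how the `wa_pack` entries returned above might be consumed (the model path is hypothetical; the activation entry carries only a name):

```python
import onnx

model = onnx.load("model.onnx")  # hypothetical model file
wa_pack = _find_quantizable_weights(model.graph, nodes_to_exclude=[])
for act_info, weight_init, transposed, io_type in wa_pack:
    # act_info: ValueInfoProto naming the Gemm/MatMul activation input
    # weight_init: TensorProto initializer holding the 2D weight
    # transposed: True for Gemm nodes with transB set
    # io_type: ONNX data type enum of the weight tensor
    print(act_info.name, weight_init.name, transposed, io_type)
```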
| 668 | + |
| 669 | +def should_quantize_to_int8(layer_name: str, int8_layers: list[str]) -> bool: |
| 670 | + """Check if layer should be quantized to INT8. |
| 671 | +
|
| 672 | + The int8_layers list contains ONNX node names like '/model/layers.13/attn/qkv_proj/MatMul'. |
| 673 | + The layer_name argument is an ONNX initializer name like 'model.layers.13.attn.qkv_proj.MatMul.weight'. |
| 674 | +
|
| 675 | + To match these, we: |
| 676 | + - Remove the leading slash from the node name. |
| 677 | + - Replace all '/' with '.' to match the naming convention of the initializer. |
| 678 | +
|
| 679 | + This allows us to correctly identify which weights should be quantized to INT8. |
| 680 | + """ |
| 681 | + if not int8_layers: |
| 682 | + return False |
| 683 | + |
| 684 | + # Normalize both to dot-delimited tokens and require exact token sequence match. |
| 685 | + def tokens(s: str) -> list[str]: |
| 686 | + return s.lstrip("/").replace("/", ".").split(".") |
| 687 | + |
| 688 | + hay = tokens(layer_name) |
| 689 | + for pat in int8_layers: |
| 690 | + needle = tokens(pat) |
| 691 | + n, m = len(hay), len(needle) |
| 692 | + for i in range(n - m + 1): |
| 693 | + if hay[i : i + m] == needle: |
| 694 | + return True |
| 695 | + return False |
| 696 | + |
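A quick illustration of the token-wise matching rule described above (names are illustrative):

```python
layers = ["/model/layers.1/attn/qkv_proj/MatMul"]

# The node name normalizes to a token sequence contained in the initializer name.
assert should_quantize_to_int8("model.layers.1.attn.qkv_proj.MatMul.weight", layers)

# Exact token comparison avoids substring false positives:
# the 'layers.1' pattern must not capture layer 13.
assert not should_quantize_to_int8("model.layers.13.attn.qkv_proj.MatMul.weight", layers)
```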
| 697 | + |
| 698 | +def validate_int8_layers(layers_str: str) -> bool: |
| 699 | + """Validate the format of int8_layers string.""" |
| 700 | + if not layers_str: |
| 701 | + return True |
| 702 | +    # Basic validation: comma-separated names built from letters, digits, '_', '.', '/', '-'. |
| 703 | +    # `re` is already imported at module level. |
|
| 705 | +    pattern = r"^[a-zA-Z0-9_./,\-]+$" |
| 706 | + return bool(re.match(pattern, layers_str)) |
| 707 | + |
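For instance (patterns are illustrative):

```python
assert validate_int8_layers("")  # empty string means no INT8 overrides and is accepted
assert validate_int8_layers("/model/layers.0/attn/qkv_proj/MatMul,/model/layers.0/mlp/down_proj/MatMul")
assert not validate_int8_layers("layers with spaces")  # whitespace is rejected
```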
| 708 | + |
| 709 | +def get_layer_precision_mapping( |
| 710 | + onnx_model: onnx.ModelProto, |
| 711 | + int8_precision_pattern: str | None = None, |
| 712 | +    nodes_to_exclude: list[str] | None = None, |
|
| 713 | +) -> dict[str, int]: |
| 714 | + """Generate a mapping of layer names to their quantization precision (INT4 or INT8) for an ONNX model. |
| 715 | +
|
| 716 | + Args: |
| 717 | + onnx_model (onnx.ModelProto): The ONNX model to analyze. |
| 718 | + int8_precision_pattern (str, optional): Comma-separated string of layer patterns to quantize to INT8. |
| 719 | + If None, a default set of patterns is used to select layers for INT8 quantization. |
| 720 | + nodes_to_exclude (list[str], optional): List of node name patterns to exclude from quantization. |
| 721 | + Defaults to [r"/lm_head"]. |
| 722 | +
|
| 723 | + Returns: |
| 724 | +        dict[str, int]: A mapping from weight tensor names to their quantization bit width (8 or 4). |
| 725 | + """ |
| 726 | + graph = onnx_model.graph |
| 727 | + |
| 728 | +    nodes_to_exclude = expand_node_names_from_patterns(graph, nodes_to_exclude or [r"/lm_head"]) |
| 729 | + # Collect quantizable weight tensors |
| 730 | + wa_pack = _find_quantizable_weights(graph, nodes_to_exclude) |
| 731 | + |
| 732 | + if int8_precision_pattern: |
| 733 | + if not validate_int8_layers(int8_precision_pattern): |
| 734 | + raise ValueError("Invalid format for --int8_layers. Use comma-separated layers.") |
| 735 | + int8_layers_list = [x.strip() for x in int8_precision_pattern.split(",") if x.strip()] |
| 736 | + |
| 737 | + else: |
| 738 | + matmul_nodes = [ |
| 739 | + node |
| 740 | + for node in onnx_model.graph.node |
| 741 | + if node.op_type == "MatMul" and "lm_head" not in node.name |
| 742 | + ] |
| 743 | + |
| 744 | +        # Only include nodes matching the per-layer patterns below, i.e. every |
| 745 | +        # layer index i for which the model contains a node named: |
| 746 | + # /model/layers.{i}/attn/qkv_proj/MatMul |
| 747 | + # /model/layers.{i}/attn/v_proj/MatMul |
| 748 | + # /model/layers.{i}/mlp/down_proj/MatMul |
| 749 | + pattern_regexes = [ |
| 750 | + re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"), |
| 751 | + re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"), |
| 752 | + re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"), |
| 753 | + ] |
| 754 | + |
| 755 | + # Filter matmul_nodes to only those matching the patterns |
| 756 | + filtered_matmul_nodes = [] |
| 757 | + for node in matmul_nodes: |
| 758 | + for pat in pattern_regexes: |
| 759 | + if pat.match(node.name): |
| 760 | + filtered_matmul_nodes.append(node) |
| 761 | + break |
| 762 | + |
| 763 | + # Build a mapping from group key to list of node names (ordered by layer index if possible) |
| 764 | + def extract_group_key(node_name): |
| 765 | +        # Group by the two path components before 'MatMul', e.g. '/model/layers.0/attn/qkv_proj/MatMul' -> 'attn.qkv_proj' |
| 766 | + parts = node_name.split("/") |
| 767 | + if len(parts) >= 3: |
| 768 | + return ".".join(parts[-3:-1]) |
| 769 | + return node_name |
| 770 | + |
| 771 | + group_to_nodes = {} |
| 772 | + for node in filtered_matmul_nodes: |
| 773 | + group_key = extract_group_key(node.name) |
| 774 | + group_to_nodes.setdefault(group_key, []).append(node.name) |
| 775 | + |
| 776 | + int8_layers_set = set() |
| 777 | + for names in group_to_nodes.values(): |
| 778 | + n = len(names) |
| 779 | + if n == 0: |
| 780 | + continue |
| 781 | + |
| 782 | + # Try to sort by layer index if present |
| 783 | + def layer_idx(name): |
| 784 | +            m = re.search(r"layers\.(\d+)", name)  # node names use '/' after the index, so don't require a trailing '.' |
| 785 | + return int(m.group(1)) if m else 0 |
| 786 | + |
| 787 | + names_sorted = sorted(names, key=layer_idx) |
| 788 | +        first_eighth = n // 8 |
| 789 | +        last_eighth = n // 8 |
| 790 | + # First 1/8 |
| 791 | + int8_layers_set.update(names_sorted[:first_eighth]) |
| 792 | + # Last 1/8 |
| 793 | + if last_eighth > 0: |
| 794 | + int8_layers_set.update(names_sorted[-last_eighth:]) |
| 795 | + # Every third in the rest (excluding first and last eighth) |
| 796 | + rest_start = first_eighth |
| 797 | + rest_end = n - last_eighth |
| 798 | + for i in range(rest_start, rest_end): |
| 799 | + if (i - rest_start) % 3 == 0: |
| 800 | + int8_layers_set.add(names_sorted[i]) |
| 801 | + int8_layers_list = list(int8_layers_set) |
| 802 | + |
| 803 | +    # Build the weight-name -> bit-width mapping |
| 804 | + precision_info = {} |
| 805 | +    for _act_tensor, weight_tensor, _do_transpose, _gemm_io_type in wa_pack: |
| 806 | + weight_name = weight_tensor.name |
| 807 | + if should_quantize_to_int8(weight_name, int8_layers_list): |
| 808 | + precision_info[weight_name] = 8 |
| 809 | + else: |
| 810 | + precision_info[weight_name] = 4 |
| 811 | + return precision_info |
| 812 | + |
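To sanity-check the default selection heuristic above, a small standalone sketch that reproduces the same arithmetic for a hypothetical group of 32 per-layer nodes:

```python
n = 32
eighth = n // 8                            # 4 nodes at each end
selected = set(range(eighth))              # first eighth: layers 0-3
selected |= set(range(n - eighth, n))      # last eighth: layers 28-31
# every third node in the middle span, starting at the first middle index
selected |= {i for i in range(eighth, n - eighth) if (i - eighth) % 3 == 0}
print(sorted(selected))  # 16 of the 32 layers land in INT8; the rest stay INT4
```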
| 813 | + |
| 814 | +def get_precision_info( |
| 815 | + onnx_model: onnx.ModelProto, |
| 816 | +    nodes_to_exclude: list[str] | None = None, |
| 817 | + **kwargs: Any, |
| 818 | +) -> dict[str, int] | None: |
| 819 | + """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits). |
| 820 | +
|
| 821 | + This function determines the quantization precision for each weight tensor in the ONNX model, |
| 822 | + based on the provided configuration. If mixed quantization is enabled, it uses the layer |
| 823 | + precision mapping; otherwise, it returns None. |
| 824 | +
|
| 825 | + Args: |
| 826 | + onnx_model (onnx.ModelProto): The ONNX model to analyze. |
| 827 | + nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization. |
| 828 | + **kwargs: Additional keyword arguments, such as: |
| 829 | + - enable_mixed_quant (bool): Whether to enable mixed quantization. |
| 830 | + - int8_layers (str): Comma-separated list of layer patterns to quantize to INT8. |
| 831 | +
|
| 832 | + Returns: |
| 833 | + dict[str, int] | None: A mapping from weight tensor names to their quantization precision, |
| 834 | + or None if mixed quantization is not enabled. |
| 835 | + """ |
| 836 | +    enable_mixed_quant = kwargs.get("enable_mixed_quant", False) |
| 837 | +    int8_layers = kwargs.get("int8_layers") |
| 838 | +    if not enable_mixed_quant: |
| 839 | +        return None |
| 840 | +    return get_layer_precision_mapping(onnx_model, int8_layers, nodes_to_exclude) |
| 844 | + |
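Finally, a hedged end-to-end usage sketch; the file name and layer patterns are hypothetical:

```python
import onnx

model = onnx.load("decoder_model.onnx")  # hypothetical file
precision_info = get_precision_info(
    model,
    nodes_to_exclude=[r"/lm_head"],
    enable_mixed_quant=True,
    int8_layers="/model/layers.0/attn/qkv_proj/MatMul,/model/layers.0/mlp/down_proj/MatMul",
)
# precision_info maps initializer names to bit widths, e.g.
# {"model.layers.0.attn.qkv_proj.MatMul.weight": 8, ...}
```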
| 845 | + |
628 | 846 | def expand_node_names_from_patterns(
|
629 | 847 | graph: onnx.GraphProto | Graph, name_patterns: list[str] | None = None
|
630 | 848 | ) -> list[str]:
|