|
40 | 40 | save_onnx, |
41 | 41 | ) |
42 | 42 |
|
| 43 | +DEFAULT_GATHER_BLOCK_SIZE = 32 |
| 44 | +DEFAULT_GATHER_QUANTIZE_AXIS = None  # None disables Gather quantization by default |
| 45 | + |
43 | 46 |
|
44 | 47 | def is_const_input(tensor: Tensor) -> bool: |
45 | 48 | """Returns whether the given tensor is an initializer or produced by const-foldable nodes.""" |
@@ -718,6 +721,8 @@ def get_layer_precision_mapping( |
718 | 721 | onnx_model: onnx.ModelProto, |
719 | 722 | precision_pattern_8bit: str | None = None, |
720 | 723 | nodes_to_exclude: list[str] | None = [r"/lm_head"], |
| 724 | + block_size: int = 128, |
| 725 | + quantize_axis: int = 0, |
721 | 726 | ): |
722 | 727 | """Generate a mapping of layer names to their quantization precision (4 bits or 8 bits) for an ONNX model. |
723 | 728 |
|
@@ -746,7 +751,7 @@ def get_layer_precision_mapping( |
746 | 751 | matmul_nodes = [ |
747 | 752 | node |
748 | 753 | for node in onnx_model.graph.node |
749 | | - if node.op_type == "MatMul" and "lm_head" not in node.name |
| 754 | + if node.op_type in ["Gemm", "MatMul"] and "lm_head" not in node.name |
750 | 755 | ] |
751 | 756 |
|
752 | 757 | # Only include nodes matching the specified patterns for all layers present in the model |
@@ -808,47 +813,81 @@ def layer_idx(name): |
808 | 813 | layers_8bit_set.add(names_sorted[i]) |
809 | 814 | layers_list_8bit = list(layers_8bit_set) |
810 | 815 |
|
811 | | - # NEW: Create precision info mapping |
812 | | - precision_info = {} |
| 816 | + # Build layer info mapping with precision, block_size, and axis |
| 817 | + layer_info = {} |
813 | 818 | for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack): |
814 | 819 | weight_name = weight_tensor.name |
815 | 820 | if should_quantize_to_8bit(weight_name, layers_list_8bit): |
816 | | - precision_info[weight_name] = 8 |
| 821 | + layer_info[weight_name] = { |
| 822 | + "precision": 8, |
| 823 | + "block_size": -1, # Per-channel for 8-bit |
| 824 | + "axis": 0, |
| 825 | + } |
817 | 826 | else: |
818 | | - precision_info[weight_name] = 4 |
819 | | - return precision_info |
| 827 | + layer_info[weight_name] = { |
| 828 | + "precision": 4, |
| 829 | + "block_size": block_size, # Default block size for 4-bit |
| 830 | + "axis": quantize_axis, |
| 831 | + } |
| 832 | + |
| 833 | + return layer_info |
820 | 834 |
|
821 | 835 |
|
822 | | -def get_precision_info( |
| 836 | +def get_layer_info( |
823 | 837 | onnx_model: onnx.ModelProto, |
824 | 838 | nodes_to_exclude: list[str] | None = [r"/lm_head"], |
| 839 | + block_size: int = 128, |
| 840 | + quantize_axis: int = 0, |
825 | 841 | **kwargs: Any, |
826 | 842 | ): |
827 | | - """Generate a mapping of weight tensor names to their quantization precision (e.g., 4 or 8 bits). |
| 843 | + """Generate a mapping of weight tensor names to their quantization configuration. |
828 | 844 |
|
829 | | - This function determines the quantization precision for each weight tensor in the ONNX model, |
830 | | - based on the provided configuration. If mixed quantization is enabled, it uses the layer |
831 | | - precision mapping; otherwise, it returns None. |
| 845 | + This function determines the quantization configuration (precision, block_size, axis) for each |
| 846 | + weight tensor in the ONNX model. If mixed quantization is enabled, it builds the full layer |
| 847 | + precision mapping; otherwise it returns Gather entries only (when gather_quantize_axis is set) or None. |
832 | 848 |
|
833 | 849 | Args: |
834 | 850 | onnx_model (onnx.ModelProto): The ONNX model to analyze. |
835 | 851 | nodes_to_exclude (list[str] | None): List of node name patterns to exclude from quantization. |
| 852 | + block_size (int): Default block size for 4-bit quantization. |
| 853 | + quantize_axis (int): Default quantization axis. |
836 | 854 | **kwargs: Additional keyword arguments, such as: |
837 | 855 | - enable_mixed_quant (bool): Whether to enable mixed quantization. |
838 | 856 | - layers_8bit (str): Comma-separated list of layer patterns to quantize to 8 bit. |
| 857 | + - gather_block_size (int): Block size for quantizing Gather (embedding) inputs. |
| 858 | + - gather_quantize_axis (int | None): Quantization axis for Gather inputs; None disables Gather quantization. |
839 | 859 |
|
840 | 860 | Returns: |
841 | | - dict[str, int] | None: A mapping from weight tensor names to their quantization precision, |
842 | | - or None if mixed quantization is not enabled. |
| 861 | + dict[str, dict[str, Any]] | None: A mapping from weight tensor names to their quantization |
| 862 | + configuration (keys: precision, block_size, axis), or None if neither mixed nor Gather quantization is enabled. |
843 | 863 | """ |
844 | | - precision_info = None |
| 864 | + layer_info = None |
845 | 865 | enable_mixed_quant = kwargs.get("enable_mixed_quant", False) |
846 | 866 | layers_8bit = kwargs.get("layers_8bit") |
| 867 | + gather_block_size = kwargs.get("gather_block_size", DEFAULT_GATHER_BLOCK_SIZE) |
| 868 | + gather_quantize_axis = kwargs.get("gather_quantize_axis", DEFAULT_GATHER_QUANTIZE_AXIS) |
847 | 869 | if enable_mixed_quant: |
848 | | - precision_info = get_layer_precision_mapping(onnx_model, layers_8bit, nodes_to_exclude) |
| 870 | + layer_info = get_layer_precision_mapping( |
| 871 | + onnx_model, |
| 872 | + layers_8bit, |
| 873 | + nodes_to_exclude, |
| 874 | + block_size, |
| 875 | + quantize_axis, |
| 876 | + ) |
849 | 877 | else: |
850 | | - precision_info = None |
851 | | - return precision_info |
| 878 | + layer_info = None |
| 879 | + |
| 880 | + if gather_quantize_axis is not None: # quantize Gather data inputs (e.g., embedding tables) |
| 881 | + if layer_info is None: |
| 882 | + layer_info = {} |
| 883 | + for node in onnx_model.graph.node: |
| 884 | + if node.op_type == "Gather": |
| 885 | + layer_info[node.input[0]] = { # input[0] is the gathered data tensor |
| 886 | + "precision": 4, |
| 887 | + "block_size": gather_block_size, |
| 888 | + "axis": gather_quantize_axis, |
| 889 | + } |
| 890 | + return layer_info |
852 | 891 |
|
853 | 892 |
|
854 | 893 | def expand_node_names_from_patterns( |
|
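A minimal usage sketch of the renamed `get_layer_info`, as implied by this diff. The model path, the 8-bit pattern string, and the tensor configs in the trailing comments are hypothetical; only the signature, the kwargs, and the shape of the returned entries come from the change itself.

```python
import onnx

model = onnx.load("decoder_model.onnx")  # hypothetical model path

layer_info = get_layer_info(
    model,
    nodes_to_exclude=[r"/lm_head"],
    block_size=128,                # 4-bit block size for MatMul/Gemm weights
    quantize_axis=0,
    enable_mixed_quant=True,       # consumed via **kwargs
    layers_8bit="/attn/qkv_proj",  # hypothetical pattern string
    gather_quantize_axis=0,        # opts Gather (embedding) inputs in; the default None leaves them out
)

# Each value is now a per-tensor config dict rather than a bare precision int:
#   {"precision": 8, "block_size": -1, "axis": 0}   -> 8-bit, per-channel (block_size == -1)
#   {"precision": 4, "block_size": 128, "axis": 0}  -> 4-bit, block-quantized MatMul/Gemm weight
#   {"precision": 4, "block_size": 32, "axis": 0}   -> Gather input, DEFAULT_GATHER_BLOCK_SIZE
```

Callers that previously consumed the flat `precision_info` mapping (an `int` per weight name) will need to read `layer_info[name]["precision"]` instead.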