# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union

import torch
from compressed_tensors.config import CompressionFormat, SparsityStructure
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from compressed_tensors.quantization.utils import is_module_quantized
from loguru import logger
| 27 | + |
| 28 | +__all__ = ["infer_and_set_per_module_quantization_format"] |
| 29 | + |
| 30 | + |
def _get_quant_compression_format(
    input_args: QuantizationArgs,
    weight_args: QuantizationArgs,
    sparsity_structure: Optional[str] = None,
) -> CompressionFormat:
    """
    Using the weight and input quantization args as well as an optional
    sparsity structure, determine the compression format that should be
    applied to a given module.

    Callers must guarantee ``weight_args`` is not None; only ``input_args``
    may be None (weight-only quantization).

    :param input_args: input quantization parameters, or None for
        weight-only quantization
    :param weight_args: weight quantization parameters
    :param sparsity_structure: optional (global) model sparsity structure
    :return: CompressionFormat for the module
    """
    is_24_structure = (
        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
    )
    is_weight_only = weight_args is not None and input_args is None

    # FP4 weights always use the nvfp4 packed format, regardless of
    # whether activations are quantized
    if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
        return CompressionFormat.nvfp4_pack_quantized

    if is_weight_only:  # w4a16 and w8a16
        is_valid_pack = (
            weight_args.num_bits in [4, 8]
            and weight_args.type == QuantizationType.INT.value
        )
        if not is_valid_pack:  # packing only valid for int4 and int8
            return CompressionFormat.naive_quantized
        if is_24_structure:
            # Use equality rather than identity here: `is not` on the
            # enum's string value only works by accident of CPython
            # string interning and breaks for non-interned/enum values.
            if (
                weight_args.strategy != QuantizationStrategy.CHANNEL.value
                and weight_args.strategy != QuantizationStrategy.GROUP.value
            ):
                # marlin24 kernel only applicable for channel/group quantization
                return CompressionFormat.pack_quantized
            return CompressionFormat.marlin_24
        return CompressionFormat.pack_quantized

    else:  # w8a8 float and int
        if (
            weight_args.type == QuantizationType.FLOAT.value
            and weight_args.num_bits == 8
        ):
            return CompressionFormat.float_quantized
        if weight_args.type == QuantizationType.INT.value:
            return CompressionFormat.int_quantized

        # fallback for unrecognized weight/activation combinations
        return CompressionFormat.naive_quantized
| 82 | + |
| 83 | + |
def set_per_module_format(
    module: torch.nn.Module, sparsity_structure: Optional[str] = None
):
    """
    Determine and set the per-module quantization format given quantization
    args and sparsity structure.

    If the module already carries an explicit format, it is kept as-is and a
    warning is emitted when it disagrees with the inferred one; otherwise the
    inferred format is written onto the module's quantization scheme.

    :param module: module which has its quantization format inferred
    :param sparsity_structure: optional sparsity applied to the module
    """
    scheme = module.quantization_scheme
    if scheme.weights is None:
        # no weight quantization - nothing to compress
        return

    inferred = _get_quant_compression_format(
        scheme.input_activations, scheme.weights, sparsity_structure
    )

    if scheme.format is None:
        # no user-provided format; record the inferred one
        scheme.format = inferred.value
    elif scheme.format != inferred.value:
        # keep the user's explicit format, but warn about the mismatch
        logger.warning(
            "The provided format for the module does not match the "
            "inferred format. Compression may fail "
        )
| 114 | + |
| 115 | + |
def infer_and_set_per_module_quantization_format(
    model: torch.nn.Module,
    sparsity_structure: Optional[str] = None,
) -> Union[str, List[str]]:
    """
    Infers the quantization format for a model based on its state and provided
    compression arguments. Updates the quantization_scheme.format value based
    on the inferred format. Returns the unique list of formats found in the
    model, or the dense format value when no quantized modules are present.

    For a summary of the formats, see `docs/guides/compression_formats.md`.

    :param model: model to check for quantization
    :param sparsity_structure: optional sparsity applied to the module
    :return: compression format(s) appropriate for the model
    """
    seen_formats: List[str] = []
    quantized = (m for m in model.modules() if is_module_quantized(m))
    for submodule in quantized:
        set_per_module_format(submodule, sparsity_structure)
        fmt = submodule.quantization_scheme.format
        if fmt not in seen_formats:
            seen_formats.append(fmt)

    # fall back to the dense format when nothing in the model is quantized
    return seen_formats if seen_formats else CompressionFormat.dense.value