diff --git a/setup.py b/setup.py
index 58a1105f4..d60b8d8ce 100644
--- a/setup.py
+++ b/setup.py
@@ -88,7 +88,7 @@ def _setup_packages() -> List:
     )
 
 def _setup_install_requires() -> List:
-    return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict"]
+    return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict", "loguru"]
 
 def _setup_extras() -> Dict:
     return {
diff --git a/src/compressed_tensors/config/__init__.py b/src/compressed_tensors/config/__init__.py
index 582b8a9e1..62abb7938 100644
--- a/src/compressed_tensors/config/__init__.py
+++ b/src/compressed_tensors/config/__init__.py
@@ -15,5 +15,6 @@
 # flake8: noqa
 from .base import *
 from .dense import *
+from .format import *
 from .sparse_24_bitmask import *
 from .sparse_bitmask import *
diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py
new file mode 100644
index 000000000..39a8d812c
--- /dev/null
+++ b/src/compressed_tensors/config/format.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import torch
+from compressed_tensors.config import CompressionFormat, SparsityStructure
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
+from compressed_tensors.quantization.utils import is_module_quantized
+from loguru import logger
+
+
+__all__ = ["infer_and_set_per_module_quantization_format"]
+
+
+def _get_quant_compression_format(
+    input_args: QuantizationArgs,
+    weight_args: QuantizationArgs,
+    sparsity_structure: Optional[str] = None,
+) -> CompressionFormat:
+    """
+    Using the weight and input quantization args as well as an optional
+    sparsity structure, determine the compression format that should be
+    applied to a given module.
+
+    :param input_args: input quantization parameters
+    :param weight_args: weight quantization parameters
+    :param sparsity_structure: optional (global) model sparsity
+        structure
+    :return CompressionFormat for the module
+    """
+    is_24_structure = (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    )
+    is_weight_only = weight_args is not None and input_args is None
+
+    # w4a16, w4a4, fp4
+    if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
+        if weight_args.strategy in (
+            QuantizationStrategy.TENSOR_GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+            QuantizationStrategy.GROUP.value,
+        ):
+            return CompressionFormat.nvfp4_pack_quantized
+        else:
+            if is_weight_only:
+                return CompressionFormat.naive_quantized
+            return CompressionFormat.float_quantized
+
+    if is_weight_only:  # w4a16 and w8a16, int
+        is_valid_pack = (
+            weight_args.num_bits in [4, 8]
+            and weight_args.type == QuantizationType.INT.value
+        )
+        if not is_valid_pack:  # packing only valid for int4 and int8
+            return CompressionFormat.naive_quantized
+        if is_24_structure:
+            if (
+                weight_args.strategy is not QuantizationStrategy.CHANNEL.value
+                and weight_args.strategy is not QuantizationStrategy.GROUP.value
+            ):
+                # marlin24 kernel only applicable for channel/group quantization
+                return CompressionFormat.pack_quantized
+            return CompressionFormat.marlin_24
+        return CompressionFormat.pack_quantized
+
+    else:  # w8a8 float and int
+        if (
+            weight_args.type == QuantizationType.FLOAT.value
+            and weight_args.num_bits == 8
+        ):
+            return CompressionFormat.float_quantized
+        if weight_args.type == QuantizationType.INT.value:
+            return CompressionFormat.int_quantized
+
+        return CompressionFormat.naive_quantized
+
+
+def set_per_module_format(
+    module: torch.nn.Module, sparsity_structure: Optional[str] = None
+):
+    """
+    Determine and set the per module quantization format given quantization args
+    and sparsity structure.
+
+    :param module: module which has its quantization format inferred
+    :param sparsity_structure: optional sparsity applied to the module
+    """
+    weight_scheme = module.quantization_scheme.weights
+    input_scheme = module.quantization_scheme.input_activations
+    if weight_scheme is None:
+        return  # no weight quant - nothing to compress
+    compression_format = _get_quant_compression_format(
+        input_scheme, weight_scheme, sparsity_structure
+    )
+
+    # If a format is already set, check whether it matches the inferred one
+    if module.quantization_scheme.format is not None:
+        # If it does not, warn the user
+        if module.quantization_scheme.format != compression_format.value:
+            logger.warning(
+                "The provided format for the module does not match the "
+                "inferred format. Compression may fail."
+            )
+    else:
+        # If not set, use the inferred format
+        module.quantization_scheme.format = compression_format.value
+
+
+def infer_and_set_per_module_quantization_format(
+    model: torch.nn.Module,
+    sparsity_structure: Optional[str] = None,
+) -> Union[str, List[str]]:
+    """
+    Infers the quantization format for a model based on its state and provided
+    compression arguments. Updates the quantization_scheme.format value
+    based on the inferred format. Returns the unique list of formats found in
+    the model, or the dense format if no quantized modules are found.
+
+    For a summary of the formats, see `docs/guides/compression_formats.md`.
+
+    :param model: model to check for quantization
+    :param sparsity_structure: optional sparsity applied to the model
+    :return compression format(s) appropriate for the model
+    """
+    unique_formats = []
+    for submodule in model.modules():
+        if is_module_quantized(submodule):
+            set_per_module_format(submodule, sparsity_structure)
+            if submodule.quantization_scheme.format not in unique_formats:
+                unique_formats.append(submodule.quantization_scheme.format)
+
+    if len(unique_formats) > 0:
+        return unique_formats
+    return CompressionFormat.dense.value
diff --git a/tests/test_configs/test_infer_quant.py b/tests/test_configs/test_infer_quant.py
new file mode 100644
index 000000000..6fa02dff5
--- /dev/null
+++ b/tests/test_configs/test_infer_quant.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+
+import pytest
+import torch
+from compressed_tensors.config.format import (
+    infer_and_set_per_module_quantization_format,
+)
+from compressed_tensors.quantization import preset_name_to_scheme
+
+
+@pytest.mark.parametrize(
+    "preset,sparsity_structure,expected_format",
+    [
+        ["W8A8", "unstructured", "int-quantized"],
+        ["W8A16", "unstructured", "pack-quantized"],
+        ["W8A16", "2:4", "marlin-24"],
+        ["W4A16", "unstructured", "pack-quantized"],
+        ["W4A16", "2:4", "marlin-24"],
+        ["FP8", "unstructured", "float-quantized"],
+    ],
+)
+def test_infer_quant_format(preset, sparsity_structure, expected_format):
+    quant_scheme = preset_name_to_scheme(preset, targets=["Linear"])
+
+    dummy_model = torch.nn.Sequential(
+        OrderedDict(
+            [
+                ("fc1", torch.nn.Linear(8, 16, bias=True)),
+                ("fc2", torch.nn.Linear(16, 32, bias=True)),
+                (
+                    "block1",
+                    torch.nn.Sequential(
+                        OrderedDict(
+                            [
+                                ("fc1", torch.nn.Linear(32, 16, bias=True)),
+                                ("fc2", torch.nn.Linear(16, 8, bias=True)),
+                            ]
+                        )
+                    ),
+                ),
+            ]
+        )
+    )
+
+    for _, module in dummy_model.named_modules():
+        module.quantization_scheme = quant_scheme
+
+    inferred_format = infer_and_set_per_module_quantization_format(
+        dummy_model, sparsity_structure=sparsity_structure
+    )
+    assert inferred_format[0] == expected_format
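Usage note (illustrative, not part of the diff): a minimal sketch of how the new entry point in src/compressed_tensors/config/format.py is expected to be called, mirroring the test above. The toy model, the W4A16 preset, and the 2:4 sparsity structure are assumptions lifted from the test parametrization; the expected "marlin-24" result comes from that table.

from collections import OrderedDict

import torch
from compressed_tensors.config.format import (
    infer_and_set_per_module_quantization_format,
)
from compressed_tensors.quantization import preset_name_to_scheme

# Toy model; any torch.nn.Module with quantized submodules works.
model = torch.nn.Sequential(
    OrderedDict(
        [
            ("fc1", torch.nn.Linear(16, 32)),
            ("fc2", torch.nn.Linear(32, 16)),
        ]
    )
)

# Attach a quantization scheme to each Linear layer (W4A16 preset here).
scheme = preset_name_to_scheme("W4A16", targets=["Linear"])
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        module.quantization_scheme = scheme

# Infer and record each module's compression format; per the test table,
# W4A16 combined with a 2:4 sparsity structure resolves to "marlin-24".
formats = infer_and_set_per_module_quantization_format(
    model, sparsity_structure="2:4"
)
print(formats)  # expected: ["marlin-24"]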