From 211fc85b7449ba03b0c0b977139d1656a6ad040d Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Fri, 29 Aug 2025 16:24:56 +0000
Subject: [PATCH 1/7] add format infer code

---
 src/compressed_tensors/config/__init__.py |   1 +
 src/compressed_tensors/config/format.py   | 114 ++++++++++++++++++++++
 tests/test_configs/test_infer_quant.py    |  50 ++++++++++
 3 files changed, 165 insertions(+)
 create mode 100644 src/compressed_tensors/config/format.py
 create mode 100644 tests/test_configs/test_infer_quant.py

diff --git a/src/compressed_tensors/config/__init__.py b/src/compressed_tensors/config/__init__.py
index 582b8a9e1..35c479596 100644
--- a/src/compressed_tensors/config/__init__.py
+++ b/src/compressed_tensors/config/__init__.py
@@ -17,3 +17,4 @@
 from .dense import *
 from .sparse_24_bitmask import *
 from .sparse_bitmask import *
+from .format import *
diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py
new file mode 100644
index 000000000..f583533af
--- /dev/null
+++ b/src/compressed_tensors/config/format.py
@@ -0,0 +1,114 @@
+from typing import List, Optional
+
+from compressed_tensors import CompressionFormat
+from compressed_tensors.config import SparsityStructure
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
+from compressed_tensors.quantization.utils import is_module_quantized
+from loguru import logger
+
+__all__ = ["infer_and_set_per_module_quantization_format"]
+
+
+def _get_quant_compression_format(
+    input_args: QuantizationArgs,
+    weight_args: QuantizationArgs,
+    sparsity_structure: Optional[str] = None,
+):
+    is_24_structure = (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    )
+    is_weight_only = weight_args is not None and input_args is None
+
+    if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
+        return CompressionFormat.nvfp4_pack_quantized
+
+    if is_weight_only:  # w4a16 and w8a16
+        is_valid_pack = (
+            weight_args.num_bits in [4, 8]
+            and weight_args.type == QuantizationType.INT.value
+        )
+        if not is_valid_pack:  # packing only valid for int4 and int8
+            return CompressionFormat.naive_quantized
+        if is_24_structure:
+            if (
+                weight_args.strategy is not QuantizationStrategy.CHANNEL.value
+                and weight_args.strategy is not QuantizationStrategy.GROUP.value
+            ):
+                # marlin24 kernel only applicable for channel/group quantization
+                return CompressionFormat.pack_quantized
+            return CompressionFormat.marlin_24
+        return CompressionFormat.pack_quantized
+
+    else:  # w8a8 float and int
+        if (
+            weight_args.type == QuantizationType.FLOAT.value
+            and weight_args.num_bits == 8
+        ):
+            return CompressionFormat.float_quantized
+        if weight_args.type == QuantizationType.INT.value:
+            return CompressionFormat.int_quantized
+
+        return CompressionFormat.naive_quantized
+
+
+def infer_and_set_per_module_quantization_format(
+    model,
+    quantization_format: Optional[str] = None,
+    save_compressed: bool = False,
+    sparsity_structure: Optional[str] = None,
+) -> Optional[List[str]]:
+    """
+    Infers the quantization format for a model based on its state and provided
+    compression arguments. Also updates the quantization_scheme.format value
+    based on the inferred format. Returns the unique list of formats in the model
+    or None if the list is empty.
+
+    For a summary of the formats, see `docs/guides/compression_formats.md`.
+
+    :param model: model to check for quantization, if the model is not quantized no
+        quantization format is returned
+    :param quantization_format: user provided quantization format, supersedes any
+        inferred quantization format
+    :param save_compressed: used to infer a quantization format if None is provided
+    :return compression format appropriate for model
+    """
+
+    if not save_compressed:
+        return None
+
+    if quantization_format:
+        return [quantization_format]
+
+    unique_formats = []
+    for submodule in model.modules():
+        if is_module_quantized(submodule):
+            weight_scheme = submodule.quantization_scheme.weights
+            input_scheme = submodule.quantization_scheme.input_activations
+            if weight_scheme is None:
+                continue  # no weight quant - nothing to compress
+            compression_format = _get_quant_compression_format(
+                input_scheme, weight_scheme, sparsity_structure
+            )
+
+            # If set, we check if it matches our inferred one
+            if submodule.quantization_scheme.format is not None:
+                # If it does not, warn the user
+                if submodule.quantization_scheme.format != compression_format.value:
+                    logger.warning(
+                        "The provided format for the module does not match the "
+                        "inferred format. Compression may fail "
+                    )
+            else:
+                # If not set, we set ours
+                submodule.quantization_scheme.format = compression_format.value
+
+            if submodule.quantization_scheme.format not in unique_formats:
+                unique_formats.append(submodule.quantization_scheme.format)
+
+    if len(unique_formats) > 0:
+        return unique_formats
+    return None
diff --git a/tests/test_configs/test_infer_quant.py b/tests/test_configs/test_infer_quant.py
new file mode 100644
index 000000000..866a37ddd
--- /dev/null
+++ b/tests/test_configs/test_infer_quant.py
@@ -0,0 +1,50 @@
+import pytest
+from compressed_tensors.quantization import preset_name_to_scheme
+
+from compressed_tensors.config.formats import (
+    infer_and_set_per_module_quantization_format,
+)
+import torch
+
+
+@pytest.mark.parametrize(
+    "preset,sparsity_structure,expected_format",
+    [
+        ["W8A8", "unstructured", "int-quantized"],
+        ["W8A16", "unstructured", "pack-quantized"],
+        ["W8A16", "2:4", "marlin-24"],
+        ["W4A16", "unstructured", "pack-quantized"],
+        ["W4A16", "2:4", "marlin-24"],
+        ["FP8", "unstructured", "float-quantized"],
+    ],
+)
+def test_infer_quant_format(preset, sparsity_structure, expected_format):
+    quant_scheme = preset_name_to_scheme(preset, targets=["Linear"])
+
+    dummy_model = torch.nn.Sequential(
+        torch.nn.OrderedDict(
+            [
+                ("fc1", torch.nnLinear(8, 16, bias=True)),
+                ("fc2", torch.nn.Linear(16, 32, bias=True)),
+                (
+                    "block1",
+                    torch.nn.Sequential(
+                        torch.nn.OrderedDict(
+                            [
+                                ("fc1", torch.nn.Linear(32, 16, bias=True)),
+                                ("fc2", torch.nn.Linear(16, 8, bias=True)),
+                            ]
+                        )
+                    ),
+                ),
+            ]
+        )
+    )
+
+    for _, module in dummy_model.named_modules():
+        module.quantization_scheme = quant_scheme
+
+    inferred_format = infer_and_set_per_module_quantization_format(
+        dummy_model, save_compressed=True, sparsity_structure=sparsity_structure
+    )
+    assert inferred_format[0] == expected_format
\ No newline at end of file

From a3bb4dd56d05cd00bbbd09081d98aa4cfb24f292 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Fri, 29 Aug 2025 19:26:27 +0000
Subject: [PATCH 2/7] update

---
 src/compressed_tensors/config/__init__.py |  2 +-
 src/compressed_tensors/config/format.py   | 18 +++++++++++--
 tests/test_configs/test_infer_quant.py    | 31 +++++++++++++++++------
 3 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/src/compressed_tensors/config/__init__.py b/src/compressed_tensors/config/__init__.py
index 35c479596..62abb7938 100644
--- a/src/compressed_tensors/config/__init__.py
+++ b/src/compressed_tensors/config/__init__.py
@@ -15,6 +15,6 @@
 # flake8: noqa
 from .base import *
 from .dense import *
+from .format import *
 from .sparse_24_bitmask import *
 from .sparse_bitmask import *
-from .format import *
diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py
index f583533af..516b997d0 100644
--- a/src/compressed_tensors/config/format.py
+++ b/src/compressed_tensors/config/format.py
@@ -1,7 +1,20 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import List, Optional
 
-from compressed_tensors import CompressionFormat
-from compressed_tensors.config import SparsityStructure
+from compressed_tensors.config import CompressionFormat, SparsityStructure
 from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -10,6 +23,7 @@
 from compressed_tensors.quantization.utils import is_module_quantized
 from loguru import logger
 
+
 __all__ = ["infer_and_set_per_module_quantization_format"]
 
 
diff --git a/tests/test_configs/test_infer_quant.py b/tests/test_configs/test_infer_quant.py
index 866a37ddd..209e1b1ec 100644
--- a/tests/test_configs/test_infer_quant.py
+++ b/tests/test_configs/test_infer_quant.py
@@ -1,10 +1,25 @@
-import pytest
-from compressed_tensors.quantization import preset_name_to_scheme
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-from compressed_tensors.config.formats import (
+from collections import OrderedDict
+
+import pytest
+import torch
+from compressed_tensors.config.format import (
     infer_and_set_per_module_quantization_format,
 )
-import torch
+from compressed_tensors.quantization import preset_name_to_scheme
 
 
@@ -22,14 +37,14 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format):
     quant_scheme = preset_name_to_scheme(preset, targets=["Linear"])
 
     dummy_model = torch.nn.Sequential(
-        torch.nn.OrderedDict(
+        OrderedDict(
             [
-                ("fc1", torch.nnLinear(8, 16, bias=True)),
+                ("fc1", torch.nn.Linear(8, 16, bias=True)),
                 ("fc2", torch.nn.Linear(16, 32, bias=True)),
                 (
                     "block1",
                     torch.nn.Sequential(
-                        torch.nn.OrderedDict(
+                        OrderedDict(
                             [
                                 ("fc1", torch.nn.Linear(32, 16, bias=True)),
                                 ("fc2", torch.nn.Linear(16, 8, bias=True)),
                             ]
                         )
                     ),
                 ),
             ]
         )
     )
@@ -47,4 +62,4 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format):
     inferred_format = infer_and_set_per_module_quantization_format(
         dummy_model, save_compressed=True, sparsity_structure=sparsity_structure
     )
-    assert inferred_format[0] == expected_format
\ No newline at end of file
+    assert inferred_format[0] == expected_format

From 5174978af306a5eff3ced157e374a8edfe82c2ce Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Fri, 29 Aug 2025 20:32:57 +0000
Subject: [PATCH 3/7] update

---
 src/compressed_tensors/config/format.py | 87 ++++++++++++++-----------
 tests/test_configs/test_infer_quant.py  |  2 +-
 2 files changed, 51 insertions(+), 38 deletions(-)

diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py
index 516b997d0..5df55af4c 100644
--- a/src/compressed_tensors/config/format.py
+++ b/src/compressed_tensors/config/format.py
@@ -14,6 +14,7 @@
 from typing import List, Optional
 
+import torch
 from compressed_tensors.config import CompressionFormat, SparsityStructure
 from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -31,7 +32,18 @@ def _get_quant_compression_format(
     input_args: QuantizationArgs,
     weight_args: QuantizationArgs,
     sparsity_structure: Optional[str] = None,
-):
+) -> CompressionFormat: 
+    """
+    Using the weight and input quantization args as well as an optional
+    sparsity structure, determine the compression format that should be 
+    applied to a given module
+
+    :param input_args: input quantization parameters
+    :param weight_args: weight quantization parameters
+    :param sparsity_structure: optional (global) model sparsity
+        structure
+    :return CompressionFormat for the module
+    """
     is_24_structure = (
         SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
     )
     is_weight_only = weight_args is not None and input_args is None
@@ -69,57 +81,58 @@ def _get_quant_compression_format(
         return CompressionFormat.naive_quantized
 
 
+def set_per_module_format(
+    module: torch.nn.Module, sparsity_structure: Optional[str] = None
+):
+    """
+    Determine and set the per module quantization format given quantization args
+    and sparsity structure.
+
+    :param module: module which has its quantization inferred
+    :param sparsity_structure: optional sparsity applied to the module
+
+    """
+    weight_scheme = module.quantization_scheme.weights
+    input_scheme = module.quantization_scheme.input_activations
+    if weight_scheme is None:
+        return  # no weight quant - nothing to compress
+    compression_format = _get_quant_compression_format(
+        input_scheme, weight_scheme, sparsity_structure
+    )
+
+    # If set, we check if it matches our inferred one
+    if module.quantization_scheme.format is not None:
+        # If it does not, warn the user
+        if module.quantization_scheme.format != compression_format.value:
+            logger.warning(
+                "The provided format for the module does not match the "
+                "inferred format. Compression may fail "
+            )
+    else:
+        # If not set, we set ours
+        module.quantization_scheme.format = compression_format.value
+
+
 def infer_and_set_per_module_quantization_format(
-    model,
-    quantization_format: Optional[str] = None,
-    save_compressed: bool = False,
+    model: torch.nn.Module,
     sparsity_structure: Optional[str] = None,
 ) -> Optional[List[str]]:
     """
     Infers the quantization format for a model based on its state and provided
-    compression arguments. Also updates the quantization_scheme.format value
+    compression arguments. Updates the quantization_scheme.format value
     based on the inferred format. Returns the unique list of formats in the model
     or None if the list is empty.
 
     For a summary of the formats, see `docs/guides/compression_formats.md`.
 
-    :param model: model to check for quantization, if the model is not quantized no
-        quantization format is returned
-    :param quantization_format: user provided quantization format, supersedes any
-        inferred quantization format
-    :param save_compressed: used to infer a quantization format if None is provided
+    :param model: model to check for quantization
+    :param sparsity_structure: optional sparsity structure applied to the model
     :return compression format appropriate for model
     """
-
-    if not save_compressed:
-        return None
-
-    if quantization_format:
-        return [quantization_format]
-
     unique_formats = []
     for submodule in model.modules():
         if is_module_quantized(submodule):
-            weight_scheme = submodule.quantization_scheme.weights
-            input_scheme = submodule.quantization_scheme.input_activations
-            if weight_scheme is None:
-                continue  # no weight quant - nothing to compress
-            compression_format = _get_quant_compression_format(
-                input_scheme, weight_scheme, sparsity_structure
-            )
-
-            # If set, we check if it matches our inferred one
-            if submodule.quantization_scheme.format is not None:
-                # If it does not, warn the user
-                if submodule.quantization_scheme.format != compression_format.value:
-                    logger.warning(
-                        "The provided format for the module does not match the "
-                        "inferred format. Compression may fail "
-                    )
-            else:
-                # If not set, we set ours
-                submodule.quantization_scheme.format = compression_format.value
-
+            set_per_module_format(submodule, sparsity_structure)
             if submodule.quantization_scheme.format not in unique_formats:
                 unique_formats.append(submodule.quantization_scheme.format)
diff --git a/tests/test_configs/test_infer_quant.py b/tests/test_configs/test_infer_quant.py
index 209e1b1ec..6fa02dff5 100644
--- a/tests/test_configs/test_infer_quant.py
+++ b/tests/test_configs/test_infer_quant.py
@@ -60,6 +60,6 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format):
         module.quantization_scheme = quant_scheme
 
     inferred_format = infer_and_set_per_module_quantization_format(
-        dummy_model, save_compressed=True, sparsity_structure=sparsity_structure
+        dummy_model, sparsity_structure=sparsity_structure
     )
     assert inferred_format[0] == expected_format

From 9bd80bb2843b774bd4c493770fa1b3c88a7cb28b Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Fri, 29 Aug 2025 20:40:21 +0000
Subject: [PATCH 4/7] add loguru

---
 setup.py                                | 2 +-
 src/compressed_tensors/config/format.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index 58a1105f4..d60b8d8ce 100644
--- a/setup.py
+++ b/setup.py
@@ -88,7 +88,7 @@ def _setup_packages() -> List:
     )
 
 def _setup_install_requires() -> List:
-    return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict"]
+    return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict", "loguru"]
 
 def _setup_extras() -> Dict:
     return {
diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py
index 5df55af4c..79ba7f5c8 100644
--- a/src/compressed_tensors/config/format.py
+++ b/src/compressed_tensors/config/format.py
@@ -32,10 +32,10 @@ def _get_quant_compression_format(
     input_args: QuantizationArgs,
     weight_args: QuantizationArgs,
     sparsity_structure: Optional[str] = None,
-) -> CompressionFormat: 
+) -> CompressionFormat:
     """
     Using the weight and input quantization args as well as an optional
-    sparsity structure, determine the compression format that should be 
+    sparsity structure, determine the compression format that should be
     applied to a given module
 

From 5bf3212d31af6d74bc6ecf0e2b77b0c4b630d689 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Fri, 29 Aug 2025 20:57:18 +0000
Subject: [PATCH 5/7] use dense not None

---
 src/compressed_tensors/config/format.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py
index 79ba7f5c8..ab757bef8 100644
--- a/src/compressed_tensors/config/format.py
+++ b/src/compressed_tensors/config/format.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import torch
 from compressed_tensors.config import CompressionFormat, SparsityStructure
@@ -116,7 +116,7 @@ def set_per_module_format(
 def infer_and_set_per_module_quantization_format(
     model: torch.nn.Module,
     sparsity_structure: Optional[str] = None,
-) -> Optional[List[str]]:
+) -> Union[str, List[str]]:
     """
     Infers the quantization format for a model based on its state and provided
     compression arguments. Updates the quantization_scheme.format value
@@ -138,4 +138,4 @@ def infer_and_set_per_module_quantization_format(
 
     if len(unique_formats) > 0:
         return unique_formats
-    return None
+    return CompressionFormat.dense.value

From 25acb7d3bf1c88ebde85aa4c32455a1a9be5117d Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 8 Sep 2025 21:30:15 +0000
Subject: [PATCH 6/7] update

---
 src/compressed_tensors/config/format.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py
index ab757bef8..dde62941f 100644
--- a/src/compressed_tensors/config/format.py
+++ b/src/compressed_tensors/config/format.py
@@ -49,10 +49,19 @@ def _get_quant_compression_format(
     )
     is_weight_only = weight_args is not None and input_args is None
 
+    # w4a16, w4a4, fp4
     if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
-        return CompressionFormat.nvfp4_pack_quantized
+        if weight_args.strategy in (
+            QuantizationStrategy.TENSOR_GROUP.value,
+            QuantizationStrategy.GROUP.value,
+        ):
+            return CompressionFormat.nvfp4_pack_quantized
+        else:
+            if is_weight_only:
+                return CompressionFormat.naive_quantized 
+            return CompressionFormat.float_quantized
 
-    if is_weight_only:  # w4a16 and w8a16
+    if is_weight_only:  # w4a16 and w8a16, int
         is_valid_pack = (
             weight_args.num_bits in [4, 8]
             and weight_args.type == QuantizationType.INT.value

From 27b317eddf53e428d52128a241d5e18aa84997fe Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 8 Sep 2025 23:58:57 +0000
Subject: [PATCH 7/7] update

---
 src/compressed_tensors/config/format.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py
index dde62941f..39a8d812c 100644
--- a/src/compressed_tensors/config/format.py
+++ b/src/compressed_tensors/config/format.py
@@ -53,12 +53,13 @@ def _get_quant_compression_format(
     if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
         if weight_args.strategy in (
             QuantizationStrategy.TENSOR_GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
             QuantizationStrategy.GROUP.value,
         ):
             return CompressionFormat.nvfp4_pack_quantized
         else:
             if is_weight_only:
-                return CompressionFormat.naive_quantized 
+                return CompressionFormat.naive_quantized
             return CompressionFormat.float_quantized
 
     if is_weight_only:  # w4a16 and w8a16, int