From 7b2f3bce89ad0bdbb7b3009a4745dfa5e74d299c Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 21:12:16 +0000 Subject: [PATCH 01/10] Add `python-style` checks to pr ci Signed-off-by: Fynn Schmitt-Ulms --- .github/workflows/test-check.yaml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/.github/workflows/test-check.yaml b/.github/workflows/test-check.yaml index 37804dc4..db46606e 100644 --- a/.github/workflows/test-check.yaml +++ b/.github/workflows/test-check.yaml @@ -26,3 +26,21 @@ jobs: run: pip3 install .[dev,accelerate] - name: "๐Ÿ”ฌ Running tests" run: make test + + python-style: + runs-on: ubuntu-24.04 + steps: + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + - name: Set Env + run: | + pip3 install --upgrade pip && pip3 install --upgrade setuptools + - name: "โš™๏ธ Install dependencies" + run: pip3 install .[dev] + - name: "๐Ÿ”ฌ Running quality checks" + run: make quality From c8b26befe7c84191dd68659f916a0128a8db3077 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 23:50:46 +0000 Subject: [PATCH 02/10] Ignore auto-generated version.py file in copyright check Signed-off-by: Fynn Schmitt-Ulms --- utils/copyright.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/utils/copyright.py b/utils/copyright.py index 7cc68dd0..433bccc3 100644 --- a/utils/copyright.py +++ b/utils/copyright.py @@ -142,6 +142,9 @@ def _get_files(patterns: List[str]) -> List[str]: def _dont_copyright(file_path: str) -> bool: + if file_path.endswith("compressed_tensors/version.py"): + return True + with open(file_path, "r") as file: content = file.read() @@ -343,4 +346,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 38f81615c585f5e2e222ca2582ba7b1afc3aadc8 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 21:33:09 +0000 Subject: [PATCH 03/10] Run `make style` Signed-off-by: Fynn Schmitt-Ulms --- .../compressors/model_compressors/model_compressor.py | 4 ++-- .../compressors/quantized_compressors/nvfp4_quantized.py | 1 + .../compressors/sparse_quantized_compressors/marlin_24.py | 2 +- src/compressed_tensors/registry/registry.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index f567e5a6..1272add8 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -195,7 +195,7 @@ def from_pretrained_model( @staticmethod def parse_sparsity_config( - compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"] + compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"], ) -> Union[Dict[str, Any], None]: """ Parse sparsity config from quantization/compression config. Sparsity @@ -215,7 +215,7 @@ def parse_sparsity_config( @staticmethod def parse_quantization_config( - compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"] + compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"], ) -> Union[Dict[str, Any], None]: """ Parse quantization config from quantization/compression config. The diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 5f348e91..5d521fbc 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -154,6 +154,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 ) + # reference: : https://github.com/vllm-project/vllm/pull/16362 def unpack_fp4_from_uint8( a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 diff --git a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py index 7b8fea02..a5d6a6b7 100644 --- a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +++ b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py @@ -48,7 +48,7 @@ class Marlin24Compressor(BaseCompressor): @staticmethod def validate_quant_compatability( - names_to_scheme: Dict[str, QuantizationScheme] + names_to_scheme: Dict[str, QuantizationScheme], ) -> bool: """ Checks if every quantized module in the model is compatible with Marlin24 diff --git a/src/compressed_tensors/registry/registry.py b/src/compressed_tensors/registry/registry.py index 39815a0b..3ad98d85 100644 --- a/src/compressed_tensors/registry/registry.py +++ b/src/compressed_tensors/registry/registry.py @@ -55,7 +55,7 @@ def standardize_lookup_name(name: str) -> str: def standardize_alias_name( - name: Union[None, str, List[str]] + name: Union[None, str, List[str]], ) -> Union[None, str, List[str]]: if name is None: return None From c49380a66d7849957c4e54c581c9630667ec6162 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 22:08:53 +0000 Subject: [PATCH 04/10] Fix line lengths Signed-off-by: Fynn Schmitt-Ulms --- .../model_compressors/model_compressor.py | 28 +++++++++---------- .../compressors/quantized_compressors/base.py | 3 +- .../quantized_compressors/nvfp4_quantized.py | 5 ++-- .../quantized_compressors/pack_quantized.py | 6 ++-- .../quantization/lifecycle/apply.py | 8 +++--- .../quantization/lifecycle/forward.py | 7 +++-- .../quantization/lifecycle/initialize.py | 13 +++++---- .../quantization/quant_args.py | 20 ++++++------- .../quantization/quant_scheme.py | 7 +++-- src/compressed_tensors/utils/match.py | 2 +- tests/test_utils/test_match.py | 2 +- 11 files changed, 53 insertions(+), 48 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 1272add8..c429120a 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -562,11 +562,12 @@ def decompress(self, model_path: str, model: Module): :param model_path: path to compressed weights :param model: pytorch model to load decompressed weights into - Note: decompress makes use of both _replace_sparsity_weights and _replace_weights - The variations in these methods are a result of the subtle variations between the sparsity - and quantization compressors. Specifically, quantization compressors return not just the - decompressed weight, but the quantization parameters (e.g scales, zero_point) whereas sparsity - compressors only return the decompressed weight. + Note: decompress makes use of both _replace_sparsity_weights and + _replace_weights. The variations in these methods are a result of the subtle + variations between the sparsity and quantization compressors. Specifically, + quantization compressors return not just the decompressed weight, but the + quantization parameters (e.g scales, zero_point) whereas sparsity compressors + only return the decompressed weight. """ model_path = get_safetensors_folder(model_path) @@ -598,18 +599,17 @@ def decompress(self, model_path: str, model: Module): with override_quantization_status( self.quantization_config, QuantizationStatus.FROZEN ): - names_to_scheme = apply_quantization_config( model, self.quantization_config ) # Load activation scales/zp or any other quantization parameters - # Conditionally load the weight quantization parameters if we have a dense compressor - # Or if a sparsity compressor has already been applied + # Conditionally load the weight quantization parameters if we have a + # dense compressor or if a sparsity compressor has already been applied load_pretrained_quantization_parameters( model, model_path, - # TODO: all weight quantization params will be moved to the compressor in a follow-up - # including initialization + # TODO: all weight quantization params will be moved to the + # compressor in a follow-up including initialization load_weight_quantization=( sparse_decompressed or isinstance(self.quantization_compressor, DenseCompressor) @@ -695,7 +695,6 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module): :param model: The model whose weights are to be updated. """ for name, data in tqdm(dense_weight_generator, desc="Decompressing model"): - split_name = name.split(".") prefix, param_name = ".".join(split_name[:-1]), split_name[-1] module = operator.attrgetter(prefix)(model) @@ -731,9 +730,10 @@ def _replace_weights(self, dense_weight_generator, model: Module): for param_name, param_data in data.items(): if hasattr(module, param_name): # If compressed, will have an incorrect dtype for transformers >4.49 - # TODO: we can also just skip initialization of scales/zp if in decompression in init - # to be consistent with loading which happens later as well - # however, update_data does a good shape check - should be moved to the compressor + # TODO: we can also just skip initialization of scales/zp if in + # decompression in init to be consistent with loading which happens + # later as well however, update_data does a good shape check - + # should be moved to the compressor if param_name == "weight": delattr(module, param_name) requires_grad = param_data.dtype in ( diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 302107d5..5b27aef0 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -107,7 +107,8 @@ def compress( compressed_dict[name] = value.to(compression_device) continue - # compress values on meta if loading from meta otherwise on cpu (memory movement too expensive) + # compress values on meta if loading from meta otherwise on cpu (memory + # movement too expensive) module_path = prefix[:-1] if prefix.endswith(".") else prefix quant_args = names_to_scheme[module_path].weights compressed_values = self.compress_weight( diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 5d521fbc..b41ffd1d 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -161,8 +161,9 @@ def unpack_fp4_from_uint8( ) -> torch.Tensor: """ Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values - (i.e. first four bits correspond to one fp4 value, last four corresond to a consecutive - fp4 value). The bits represent an index, which are mapped to an fp4 value. + (i.e. first four bits correspond to one fp4 value, last four correspond to a + consecutive fp4 value). The bits represent an index, which are mapped to an fp4 + value. :param a: tensor to unpack :param m: original dim 0 size of the unpacked tensor diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index d5188d23..a7af3994 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -135,7 +135,8 @@ def compress_weight( compressed_dict["weight_shape"] = weight_shape compressed_dict["weight_packed"] = packed_weight - # We typically don't compress zp; apart from when using the packed_compressor and when storing group/channel zp + # We typically don't compress zp; apart from when using the packed_compressor + # and when storing group/channel zp if not quantization_args.symmetric and quantization_args.strategy in [ QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value, @@ -166,7 +167,8 @@ def decompress_weight( num_bits = quantization_args.num_bits unpacked = unpack_from_int32(weight, num_bits, original_shape) - # NOTE: this will fail decompression as we don't currently handle packed zp on decompression + # NOTE: this will fail decompression as we don't currently handle packed zp on + # decompression if not quantization_args.symmetric and quantization_args.strategy in [ QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value, diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 7afd2aba..94fc8871 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -73,14 +73,14 @@ def load_pretrained_quantization_parameters( Loads the quantization parameters (scale and zero point) from model_name_or_path to a model that has already been initialized with a quantization config. - NOTE: Will always load inputs/output parameters. - Will conditioanlly load weight parameters, if load_weight_quantization is set to True. + NOTE: Will always load inputs/output parameters. Will conditioanlly load weight + parameters, if load_weight_quantization is set to True. :param model: model to load pretrained quantization parameters to :param model_name_or_path: Hugging Face stub or local folder containing a quantized model, which is used to load quantization parameters - :param load_weight_quantization: whether or not the weight quantization parameters shoud - be laoded + :param load_weight_quantization: whether or not the weight quantization parameters + should be loaded """ model_path = get_safetensors_folder(model_name_or_path) mapping = get_quantization_parameter_to_path_mapping(model_path) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index b82a4195..645e3938 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -200,7 +200,8 @@ def _process_quantization( q_min, q_max = calculate_range(args, x.device) group_size = args.group_size - # blockwise FP8: quantize per 2D block, supports block_structure for static block quant + # blockwise FP8: quantize per 2D block, supports block_structure for static block + # quantization if args.strategy == QuantizationStrategy.BLOCK: original_shape = x.shape rows, cols = x.shape[-2], x.shape[-1] @@ -209,8 +210,8 @@ def _process_quantization( # Ensure exact division (tensor dimensions must be divisible by block size) if rows % block_height != 0: raise ValueError( - f"Tensor height {rows} is not divisible by block_height {block_height}. " - f"Block quantization requires exact division." + f"Tensor height {rows} is not divisible by block_height {block_height}." + f" Block quantization requires exact division." ) if cols % block_width != 0: raise ValueError( diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index b0c32439..c9430e9e 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -17,7 +17,7 @@ import math import warnings from enum import Enum -from typing import List, Optional +from typing import Optional import torch from compressed_tensors.quantization.lifecycle.forward import ( @@ -87,7 +87,6 @@ def initialize_module_for_quantization( _initialize_attn_scales(module) else: - if scheme.input_activations is not None: _initialize_scale_zero_point( module, @@ -183,7 +182,8 @@ def _initialize_scale_zero_point( num_groups = math.ceil(weight_shape[1] / quantization_args.group_size) expected_shape = (weight_shape[0], max(num_groups, 1)) elif quantization_args.strategy == QuantizationStrategy.BLOCK: - # For block quantization, scale shape should match number of blocks - only for weights + # For block quantization, scale shape should match number of blocks - only + # for weights if quantization_args.block_structure is None: raise ValueError( "Block quantization requires block_structure to be specified" @@ -196,9 +196,10 @@ def _initialize_scale_zero_point( # Warn if dimensions don't divide evenly if rows % block_height != 0 or cols % block_width != 0: warnings.warn( - f"Block quantization: tensor shape {weight_shape} does not divide evenly " - f"by block structure {quantization_args.block_structure}. " - f"Some blocks will be incomplete which may affect quantization quality.", + f"Block quantization: tensor shape {weight_shape} does not divide" + f"evenly by block structure {quantization_args.block_structure}. " + f"Some blocks will be incomplete which may affect quantization" + "quality.", UserWarning, ) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 53dbf88e..97220aa8 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -211,23 +211,21 @@ def validate_group(cls, value) -> Union[int, None]: def validate_block_structure(cls, value) -> Optional[List[int]]: if value is None: return value + invalid_block_structure_msg = ( + f"Invalid block_structure '{value}'. Must be a list of two ints" + " [rows, cols]." + ) # For backward compatibility, allow string format "2x4", "8x16", etc. if isinstance(value, str): try: return [int(x) for x in value.split("x")] except Exception: - raise ValueError( - f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]." - ) + raise ValueError(invalid_block_structure_msg) if isinstance(value, (list, tuple)): if len(value) != 2 or not all(isinstance(v, int) for v in value): - raise ValueError( - f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]." - ) + raise ValueError(invalid_block_structure_msg) return list(value) - raise ValueError( - f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]." - ) + raise ValueError(invalid_block_structure_msg) @field_validator("strategy", mode="before") def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]: @@ -307,7 +305,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": ) if strategy not in supported_strategies: raise ValueError( - f"One of {supported_strategies} must be used for dynamic quantization" + f"One of {supported_strategies} must be used for dynamic quant." ) if ( @@ -322,7 +320,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": observer != "memoryless" ): # avoid annoying users with old configs warnings.warn( - "No observer is used for dynamic quantization, setting to None" + "No observer is used for dynamic quant., setting to None" ) observer = None else: diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 1124b55f..988161d8 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -72,9 +72,10 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": ): warnings.warn( "Using GROUP strategy for both weights and input_activations " - f"with different group sizes ({weights.group_size} vs {inputs.group_size}) " - "may complicate fused kernel implementations. Consider using " - "TENSOR_GROUP strategy for both or matching group sizes.", + f"with different group sizes ({weights.group_size} vs " + f"{inputs.group_size}) may complicate fused kernel implementations. " + "Consider using TENSOR_GROUP strategy for both or matching group" + " sizes.", UserWarning, stacklevel=2, ) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 21ce4a0b..a7aafdab 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -110,7 +110,7 @@ def match_modules_set( Yields modules grouped with the same order and size as `targets`. Values are returned in order of `model.named_modules()` - For example, the following targets would yield module belonging to the following layers: + E.g. the following targets would yield module belonging to the following layers: ```python3 match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == ( ( diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 705676b9..2063aa9b 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -364,7 +364,7 @@ def test_module_set_ordering(self): for module_set in matches: # Check that modules are returned in target order (v, q, k) v_proj, q_proj, k_proj = module_set - # We can"t easily check the exact modules, but we can check they"re all Linear + # We can't easily check the exact modules, but can check they're all Linear assert all(isinstance(m, nn.Linear) for m in [v_proj, q_proj, k_proj]) def test_incomplete_set_error(self): From baa97abc08785c9f03f023cb42a3f1dd479fa96d Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 22:10:49 +0000 Subject: [PATCH 05/10] Remove unused imports Signed-off-by: Fynn Schmitt-Ulms --- .../compressors/model_compressors/model_compressor.py | 2 -- .../compressors/quantized_compressors/base.py | 1 - .../compressors/quantized_compressors/nvfp4_quantized.py | 3 --- .../compressors/quantized_compressors/pack_quantized.py | 1 - .../compressors/sparse_compressors/sparse_24_bitmask.py | 2 +- src/compressed_tensors/quantization/quant_scheme.py | 2 +- src/compressed_tensors/quantization/utils/helpers.py | 1 - src/compressed_tensors/transform/factory/hadamard.py | 4 ++-- src/compressed_tensors/transform/factory/matrix_multiply.py | 2 +- src/compressed_tensors/transform/utils/matrix.py | 2 +- src/compressed_tensors/utils/safetensors_load.py | 1 - tests/test_quantization/lifecycle/test_initialize.py | 1 - 12 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index c429120a..ac5b426d 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -42,7 +42,6 @@ load_pretrained_quantization_parameters, ) from compressed_tensors.quantization.lifecycle import expand_target_names -from compressed_tensors.quantization.utils import is_module_quantized from compressed_tensors.utils import ( align_module_device, delete_offload_parameter, @@ -390,7 +389,6 @@ def compress_model(self, model: Module): ) for prefix, module in tqdm(model.named_modules(), desc="Compressing model"): - if prefix in module_to_scheme or prefix in sparse_compression_targets: module_device = get_execution_device(module) is_meta = module_device.type == "meta" diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 5b27aef0..f04624d6 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -24,7 +24,6 @@ get_nested_weight_mappings, merge_names, ) -from compressed_tensors.utils.safetensors_load import match_param_name from safetensors import safe_open from torch import Tensor from tqdm import tqdm diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index b41ffd1d..cf6cf932 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -15,7 +15,6 @@ from typing import Dict, Optional, Tuple -import numpy import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.quantized_compressors.base import ( @@ -71,7 +70,6 @@ def compress_weight( zero_point: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: - quantized_weight = quantize( x=weight, scale=scale, @@ -91,7 +89,6 @@ def decompress_weight( compressed_data: Dict[str, Tensor], quantization_args: Optional[QuantizationArgs] = None, ) -> torch.Tensor: - weight = compressed_data["weight_packed"] scale = compressed_data["weight_scale"] global_scale = compressed_data["weight_global_scale"] diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index a7af3994..e2ce3d24 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -14,7 +14,6 @@ import math from typing import Dict, Literal, Optional, Tuple, Union -import numpy as np import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.quantized_compressors.base import ( diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index 275c70f2..f11d7b42 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Dict, Generator, List, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from compressed_tensors.compressors.base import BaseCompressor diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 988161d8..b1a17ed0 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -14,7 +14,7 @@ import warnings from copy import deepcopy -from typing import Any, Dict, List, Optional +from typing import List, Optional from compressed_tensors.quantization.quant_args import ( DynamicType, diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 42a6e19e..989c9724 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -29,7 +29,6 @@ from compressed_tensors.utils import deprecated from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module -from tqdm import tqdm __all__ = [ diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 02ebd89b..a481497e 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, Union +from typing import Optional import torch from compressed_tensors.transform import TransformArgs, TransformScheme @@ -26,7 +26,7 @@ from compressed_tensors.utils import get_execution_device, get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict from torch import Tensor, device, dtype -from torch.nn import Linear, Module, Parameter +from torch.nn import Module, Parameter @TransformFactory.register("hadamard") diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 8b829451..4035d7a1 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -24,7 +24,7 @@ from compressed_tensors.utils import get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict from torch import Tensor, device, dtype -from torch.nn import Linear, Module, Parameter +from torch.nn import Module, Parameter @TransformFactory.register("random-matrix") diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py index f353f8a2..92072857 100644 --- a/src/compressed_tensors/transform/utils/matrix.py +++ b/src/compressed_tensors/transform/utils/matrix.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Tuple +from typing import Optional import torch from compressed_tensors.transform import TransformLocation diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 13a8eb48..cb2b913b 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -18,7 +18,6 @@ import struct from typing import Dict, Iterable, Optional, Tuple, Union -from safetensors import safe_open from torch import Tensor from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file diff --git a/tests/test_quantization/lifecycle/test_initialize.py b/tests/test_quantization/lifecycle/test_initialize.py index faf56ba6..e613e399 100644 --- a/tests/test_quantization/lifecycle/test_initialize.py +++ b/tests/test_quantization/lifecycle/test_initialize.py @@ -24,7 +24,6 @@ QuantizationScheme, QuantizationStatus, QuantizationStrategy, - QuantizationType, ) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, From 03560cf27c77bec0df98f8b9b0690f967790858a Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 22:14:48 +0000 Subject: [PATCH 06/10] Fix explicit comparison to True/False Signed-off-by: Fynn Schmitt-Ulms --- tests/test_utils/test_match.py | 84 +++++++++++++++------------------- 1 file changed, 38 insertions(+), 46 deletions(-) diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 2063aa9b..d168f725 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -70,43 +70,35 @@ class TestMatchName: def test_exact_match(self): """Test exact string matching""" - assert _match_name("layer1", "layer1") == True - assert _match_name("layer1", "layer2") == False - assert ( - _match_name( - "transformer.layers.0.self_attn.q_proj", - "transformer.layers.0.self_attn.q_proj", - ) - == True + assert _match_name("layer1", "layer1") + assert not _match_name("layer1", "layer2") + assert _match_name( + "transformer.layers.0.self_attn.q_proj", + "transformer.layers.0.self_attn.q_proj", ) def test_regex_match(self): """Test regex matching with "re:" prefix""" - assert _match_name("layer1", "re:layer.*") == True - assert _match_name("layer1", "re:^layer1$") == True - assert _match_name("layer1", "re:layer2") == False - assert ( - _match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj") == True - ) - assert ( - _match_name( - "transformer.layers.0.self_attn.q_proj", - "re:transformer\\.layers\\.\\d+\\.self_attn\\..*_proj$", - ) - == True + assert _match_name("layer1", "re:layer.*") + assert _match_name("layer1", "re:^layer1$") + assert not _match_name("layer1", "re:layer2") + assert _match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj") + assert _match_name( + "transformer.layers.0.self_attn.q_proj", + "re:transformer\\.layers\\.\\d+\\.self_attn\\..*_proj$", ) def test_empty_strings(self): """Test edge cases with empty strings""" - assert _match_name("", "") == True - assert _match_name("layer1", "") == False - assert _match_name("", "layer1") == False + assert _match_name("", "") + assert not _match_name("layer1", "") + assert not _match_name("", "layer1") def test_regex_special_characters(self): """Test regex with special characters""" - assert _match_name("layer.1", "re:layer\\.1") == True - assert _match_name("layer.1", "re:layer.1") == True # . matches any char - assert _match_name("layer_1", "re:layer_1") == True + assert _match_name("layer.1", "re:layer\\.1") + assert _match_name("layer.1", "re:layer.1") # . matches any char + assert _match_name("layer_1", "re:layer_1") class TestMatchClass: @@ -115,32 +107,32 @@ class TestMatchClass: def test_direct_class_match(self): """Test matching direct class names""" linear = nn.Linear(10, 20) - assert _match_class(linear, "Linear") == True - assert _match_class(linear, "Conv2d") == False + assert _match_class(linear, "Linear") + assert not _match_class(linear, "Conv2d") norm = nn.LayerNorm(10) - assert _match_class(norm, "LayerNorm") == True - assert _match_class(norm, "BatchNorm1d") == False + assert _match_class(norm, "LayerNorm") + assert not _match_class(norm, "BatchNorm1d") def test_parent_class_match(self): """Test matching parent class names""" linear = nn.Linear(10, 20) - assert _match_class(linear, "Module") == True + assert _match_class(linear, "Module") conv = nn.Conv2d(3, 16, 3) - assert _match_class(conv, "Module") == True - assert _match_class(conv, "_ConvNd") == True + assert _match_class(conv, "Module") + assert _match_class(conv, "_ConvNd") def test_non_torch_module(self): """Test with non-torch modules""" regular_object = object() - assert _match_class(regular_object, "object") == False # not a torch.nn.Module + assert not _match_class(regular_object, "object") # not a torch.nn.Module def test_custom_module(self): """Test with custom module classes""" model = DummyModel() - assert _match_class(model, "DummyModel") == True - assert _match_class(model, "Module") == True + assert _match_class(model, "DummyModel") + assert _match_class(model, "Module") class TestIsMatch: @@ -149,27 +141,27 @@ class TestIsMatch: def test_name_match(self): """Test matching by name""" linear = nn.Linear(10, 20) - assert is_match("layer1", linear, "layer1") == True - assert is_match("layer1", linear, "layer2") == False + assert is_match("layer1", linear, "layer1") + assert not is_match("layer1", linear, "layer2") def test_class_match(self): """Test matching by class""" linear = nn.Linear(10, 20) - assert is_match("layer1", linear, "Linear") == True - assert is_match("layer1", linear, "Conv2d") == False + assert is_match("layer1", linear, "Linear") + assert not is_match("layer1", linear, "Conv2d") def test_combined_match(self): """Test that either name or class match works""" linear = nn.Linear(10, 20) - assert is_match("layer1", linear, "layer1") == True # name match - assert is_match("layer1", linear, "Linear") == True # class match - assert is_match("layer1", linear, "layer2") == False # no match + assert is_match("layer1", linear, "layer1") # name match + assert is_match("layer1", linear, "Linear") # class match + assert not is_match("layer1", linear, "layer2") # no match def test_regex_in_name_match(self): """Test regex matching in name""" linear = nn.Linear(10, 20) - assert is_match("layer1", linear, "re:layer.*") == True - assert is_match("layer1", linear, "re:conv.*") == False + assert is_match("layer1", linear, "re:layer.*") + assert not is_match("layer1", linear, "re:conv.*") def test_internal_module_match(self): """Test not matching internal modules""" @@ -178,7 +170,7 @@ class InternalLinear(InternalModule, nn.Linear): pass linear = InternalLinear(10, 20) - assert is_match("layer1", linear, "re:layer.*") == False + assert not is_match("layer1", linear, "re:layer.*") class TestMatchNamedModules: From b5b9b2b4f3595b58005275fe95feabdcaa501e9e Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 22:17:53 +0000 Subject: [PATCH 07/10] Fix misc flake8 errors Signed-off-by: Fynn Schmitt-Ulms --- src/compressed_tensors/utils/offload.py | 1 - tests/test_examples/test_bitmask_compression_ipynb.py | 2 +- tests/test_transform/test_transform_config.py | 2 +- tests/test_utils/test_offload.py | 8 ++++---- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index cf3707fd..55157b4a 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -282,7 +282,6 @@ def disable_hf_hook(module: torch.nn.Module): hooks = {} def collect_hooks(module): - nonlocal hooks if hasattr(module, "_hf_hook"): hooks[module] = module._hf_hook remove_hook_from_module(module) diff --git a/tests/test_examples/test_bitmask_compression_ipynb.py b/tests/test_examples/test_bitmask_compression_ipynb.py index 12f196f6..a813dce8 100644 --- a/tests/test_examples/test_bitmask_compression_ipynb.py +++ b/tests/test_examples/test_bitmask_compression_ipynb.py @@ -16,7 +16,7 @@ nbformat = pytest.importorskip("nbformat") -from nbconvert.preprocessors import ExecutePreprocessor +from nbconvert.preprocessors import ExecutePreprocessor # noqa: E402 @pytest.mark.skip( diff --git a/tests/test_transform/test_transform_config.py b/tests/test_transform/test_transform_config.py index 52167cfd..dc725743 100644 --- a/tests/test_transform/test_transform_config.py +++ b/tests/test_transform/test_transform_config.py @@ -66,6 +66,6 @@ def test_multiple_groups(): type="hadamard", apply=[linear_args_2], ) - config = TransformConfig( + _ = TransformConfig( config_groups={"transform_0": scheme_1, "transform_1": scheme_2} ) diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index 1fce49b3..aed0186b 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -441,8 +441,8 @@ def test_delete_offload_module(exec_device): register_offload_module(model.linear, "child", child) delete_offload_module(model, "child") delete_offload_module(model.linear, "child") - assert not child in model.children() - assert not child in model.linear.children() + assert child not in model.children() + assert child not in model.linear.children() # with offloading model = ExampleModel() @@ -452,8 +452,8 @@ def test_delete_offload_module(exec_device): register_offload_module(model.linear, "child", child) delete_offload_module(model, "child") delete_offload_module(model.linear, "child") - assert not child in model.children() - assert not child in model.linear.children() + assert child not in model.children() + assert child not in model.linear.children() @requires_gpu From a4e3b48f6f75d6674e66230748b68de9520827b9 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Wed, 30 Jul 2025 14:04:46 +0000 Subject: [PATCH 08/10] Skip auto-generated `version.py` file when running isort Signed-off-by: Fynn Schmitt-Ulms --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index cf87c555..4f8387dd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,7 @@ ensure_newline_before_comments = True force_grid_wrap = 0 include_trailing_comma = True sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER +skip = src/compressed_tensors/version.py line_length = 88 lines_after_imports = 2 From 81dba6a6c6d966132fd49e06e8ea62e4ce2382ff Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Thu, 31 Jul 2025 13:17:39 +0000 Subject: [PATCH 09/10] Move quality check to separate workflow file and fix error msg Signed-off-by: Fynn Schmitt-Ulms --- .github/workflows/quality-check.yaml | 27 +++++++++++++++++++ .github/workflows/test-check.yaml | 17 ------------ .../quantization/quant_args.py | 18 ++++++++----- 3 files changed, 38 insertions(+), 24 deletions(-) create mode 100644 .github/workflows/quality-check.yaml diff --git a/.github/workflows/quality-check.yaml b/.github/workflows/quality-check.yaml new file mode 100644 index 00000000..05cadf6e --- /dev/null +++ b/.github/workflows/quality-check.yaml @@ -0,0 +1,27 @@ +name: Quality Checks +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + quality-check: + runs-on: ubuntu-24.04 + steps: + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + - name: Set Env + run: | + pip3 install --upgrade pip && pip3 install --upgrade setuptools + - name: "โš™๏ธ Install dependencies" + run: pip3 install .[dev] + - name: "๐Ÿงน Running quality checks" + run: make quality \ No newline at end of file diff --git a/.github/workflows/test-check.yaml b/.github/workflows/test-check.yaml index db46606e..ce8053ff 100644 --- a/.github/workflows/test-check.yaml +++ b/.github/workflows/test-check.yaml @@ -27,20 +27,3 @@ jobs: - name: "๐Ÿ”ฌ Running tests" run: make test - python-style: - runs-on: ubuntu-24.04 - steps: - - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - fetch-tags: true - - name: Set Env - run: | - pip3 install --upgrade pip && pip3 install --upgrade setuptools - - name: "โš™๏ธ Install dependencies" - run: pip3 install .[dev] - - name: "๐Ÿ”ฌ Running quality checks" - run: make quality diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 97220aa8..8f478ef9 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -211,21 +211,25 @@ def validate_group(cls, value) -> Union[int, None]: def validate_block_structure(cls, value) -> Optional[List[int]]: if value is None: return value - invalid_block_structure_msg = ( - f"Invalid block_structure '{value}'. Must be a list of two ints" - " [rows, cols]." - ) # For backward compatibility, allow string format "2x4", "8x16", etc. if isinstance(value, str): try: return [int(x) for x in value.split("x")] except Exception: - raise ValueError(invalid_block_structure_msg) + raise ValueError( + f"Invalid block_structure '{value}'. Must be a list of ints " + "[rows, cols]." + ) if isinstance(value, (list, tuple)): if len(value) != 2 or not all(isinstance(v, int) for v in value): - raise ValueError(invalid_block_structure_msg) + raise ValueError( + f"Invalid block_structure '{value}'. Must be a list of ints " + "[rows, cols]." + ) return list(value) - raise ValueError(invalid_block_structure_msg) + raise ValueError( + f"Invalid block_structure '{value}'. Must be a list of ints [rows, cols]." + ) @field_validator("strategy", mode="before") def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]: From 761d399d21d19ecdbe821de03ffae3eafa4665ea Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Thu, 7 Aug 2025 13:55:32 -0400 Subject: [PATCH 10/10] Add 'release/*' branch triggers to ci quality and tests Signed-off-by: Fynn Schmitt-Ulms --- .github/workflows/quality-check.yaml | 2 ++ .github/workflows/test-check.yaml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/.github/workflows/quality-check.yaml b/.github/workflows/quality-check.yaml index 05cadf6e..03ec17b6 100644 --- a/.github/workflows/quality-check.yaml +++ b/.github/workflows/quality-check.yaml @@ -3,9 +3,11 @@ on: push: branches: - main + - 'release/*' pull_request: branches: - main + - 'release/*' jobs: quality-check: diff --git a/.github/workflows/test-check.yaml b/.github/workflows/test-check.yaml index ce8053ff..bce5bf47 100644 --- a/.github/workflows/test-check.yaml +++ b/.github/workflows/test-check.yaml @@ -4,9 +4,11 @@ on: push: branches: - main + - 'release/*' pull_request: branches: - main + - 'release/*' jobs: python-tests: