diff --git a/.github/workflows/quality-check.yaml b/.github/workflows/quality-check.yaml new file mode 100644 index 00000000..03ec17b6 --- /dev/null +++ b/.github/workflows/quality-check.yaml @@ -0,0 +1,29 @@ +name: Quality Checks +on: + push: + branches: + - main + - 'release/*' + pull_request: + branches: + - main + - 'release/*' + +jobs: + quality-check: + runs-on: ubuntu-24.04 + steps: + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + - name: Set Env + run: | + pip3 install --upgrade pip && pip3 install --upgrade setuptools + - name: "⚙️ Install dependencies" + run: pip3 install .[dev] + - name: "🧹 Running quality checks" + run: make quality \ No newline at end of file diff --git a/.github/workflows/test-check.yaml b/.github/workflows/test-check.yaml index 37804dc4..bce5bf47 100644 --- a/.github/workflows/test-check.yaml +++ b/.github/workflows/test-check.yaml @@ -4,9 +4,11 @@ on: push: branches: - main + - 'release/*' pull_request: branches: - main + - 'release/*' jobs: python-tests: @@ -26,3 +28,4 @@ jobs: run: pip3 install .[dev,accelerate] - name: "🔬 Running tests" run: make test + diff --git a/setup.cfg b/setup.cfg index cf87c555..4f8387dd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,7 @@ ensure_newline_before_comments = True force_grid_wrap = 0 include_trailing_comma = True sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER +skip = src/compressed_tensors/version.py line_length = 88 lines_after_imports = 2 diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index f567e5a6..ac5b426d 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -42,7 +42,6 @@ load_pretrained_quantization_parameters, ) from compressed_tensors.quantization.lifecycle import expand_target_names -from compressed_tensors.quantization.utils import is_module_quantized from compressed_tensors.utils import ( align_module_device, delete_offload_parameter, @@ -195,7 +194,7 @@ def from_pretrained_model( @staticmethod def parse_sparsity_config( - compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"] + compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"], ) -> Union[Dict[str, Any], None]: """ Parse sparsity config from quantization/compression config. Sparsity @@ -215,7 +214,7 @@ def parse_sparsity_config( @staticmethod def parse_quantization_config( - compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"] + compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"], ) -> Union[Dict[str, Any], None]: """ Parse quantization config from quantization/compression config. 
The @@ -390,7 +389,6 @@ def compress_model(self, model: Module): ) for prefix, module in tqdm(model.named_modules(), desc="Compressing model"): - if prefix in module_to_scheme or prefix in sparse_compression_targets: module_device = get_execution_device(module) is_meta = module_device.type == "meta" @@ -562,11 +560,12 @@ def decompress(self, model_path: str, model: Module): :param model_path: path to compressed weights :param model: pytorch model to load decompressed weights into - Note: decompress makes use of both _replace_sparsity_weights and _replace_weights - The variations in these methods are a result of the subtle variations between the sparsity - and quantization compressors. Specifically, quantization compressors return not just the - decompressed weight, but the quantization parameters (e.g scales, zero_point) whereas sparsity - compressors only return the decompressed weight. + Note: decompress makes use of both _replace_sparsity_weights and + _replace_weights. The variations in these methods are a result of the subtle + variations between the sparsity and quantization compressors. Specifically, + quantization compressors return not just the decompressed weight, but the + quantization parameters (e.g scales, zero_point) whereas sparsity compressors + only return the decompressed weight. """ model_path = get_safetensors_folder(model_path) @@ -598,18 +597,17 @@ def decompress(self, model_path: str, model: Module): with override_quantization_status( self.quantization_config, QuantizationStatus.FROZEN ): - names_to_scheme = apply_quantization_config( model, self.quantization_config ) # Load activation scales/zp or any other quantization parameters - # Conditionally load the weight quantization parameters if we have a dense compressor - # Or if a sparsity compressor has already been applied + # Conditionally load the weight quantization parameters if we have a + # dense compressor or if a sparsity compressor has already been applied load_pretrained_quantization_parameters( model, model_path, - # TODO: all weight quantization params will be moved to the compressor in a follow-up - # including initialization + # TODO: all weight quantization params will be moved to the + # compressor in a follow-up including initialization load_weight_quantization=( sparse_decompressed or isinstance(self.quantization_compressor, DenseCompressor) @@ -695,7 +693,6 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module): :param model: The model whose weights are to be updated. 
""" for name, data in tqdm(dense_weight_generator, desc="Decompressing model"): - split_name = name.split(".") prefix, param_name = ".".join(split_name[:-1]), split_name[-1] module = operator.attrgetter(prefix)(model) @@ -731,9 +728,10 @@ def _replace_weights(self, dense_weight_generator, model: Module): for param_name, param_data in data.items(): if hasattr(module, param_name): # If compressed, will have an incorrect dtype for transformers >4.49 - # TODO: we can also just skip initialization of scales/zp if in decompression in init - # to be consistent with loading which happens later as well - # however, update_data does a good shape check - should be moved to the compressor + # TODO: we can also just skip initialization of scales/zp if in + # decompression in init to be consistent with loading which happens + # later as well however, update_data does a good shape check - + # should be moved to the compressor if param_name == "weight": delattr(module, param_name) requires_grad = param_data.dtype in ( diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 302107d5..f04624d6 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -24,7 +24,6 @@ get_nested_weight_mappings, merge_names, ) -from compressed_tensors.utils.safetensors_load import match_param_name from safetensors import safe_open from torch import Tensor from tqdm import tqdm @@ -107,7 +106,8 @@ def compress( compressed_dict[name] = value.to(compression_device) continue - # compress values on meta if loading from meta otherwise on cpu (memory movement too expensive) + # compress values on meta if loading from meta otherwise on cpu (memory + # movement too expensive) module_path = prefix[:-1] if prefix.endswith(".") else prefix quant_args = names_to_scheme[module_path].weights compressed_values = self.compress_weight( diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 5f348e91..cf6cf932 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -15,7 +15,6 @@ from typing import Dict, Optional, Tuple -import numpy import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.quantized_compressors.base import ( @@ -71,7 +70,6 @@ def compress_weight( zero_point: Optional[torch.Tensor] = None, g_idx: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: - quantized_weight = quantize( x=weight, scale=scale, @@ -91,7 +89,6 @@ def decompress_weight( compressed_data: Dict[str, Tensor], quantization_args: Optional[QuantizationArgs] = None, ) -> torch.Tensor: - weight = compressed_data["weight_packed"] scale = compressed_data["weight_scale"] global_scale = compressed_data["weight_global_scale"] @@ -154,14 +151,16 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 ) + # reference: : https://github.com/vllm-project/vllm/pull/16362 def unpack_fp4_from_uint8( a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 ) -> torch.Tensor: """ Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values - (i.e. 
first four bits correspond to one fp4 value, last four corresond to a consecutive - fp4 value). The bits represent an index, which are mapped to an fp4 value. + (i.e. first four bits correspond to one fp4 value, last four correspond to a + consecutive fp4 value). The bits represent an index, which are mapped to an fp4 + value. :param a: tensor to unpack :param m: original dim 0 size of the unpacked tensor diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index d5188d23..e2ce3d24 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -14,7 +14,6 @@ import math from typing import Dict, Literal, Optional, Tuple, Union -import numpy as np import torch from compressed_tensors.compressors.base import BaseCompressor from compressed_tensors.compressors.quantized_compressors.base import ( @@ -135,7 +134,8 @@ def compress_weight( compressed_dict["weight_shape"] = weight_shape compressed_dict["weight_packed"] = packed_weight - # We typically don't compress zp; apart from when using the packed_compressor and when storing group/channel zp + # We typically don't compress zp; apart from when using the packed_compressor + # and when storing group/channel zp if not quantization_args.symmetric and quantization_args.strategy in [ QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value, @@ -166,7 +166,8 @@ def decompress_weight( num_bits = quantization_args.num_bits unpacked = unpack_from_int32(weight, num_bits, original_shape) - # NOTE: this will fail decompression as we don't currently handle packed zp on decompression + # NOTE: this will fail decompression as we don't currently handle packed zp on + # decompression if not quantization_args.symmetric and quantization_args.strategy in [ QuantizationStrategy.GROUP.value, QuantizationStrategy.CHANNEL.value, diff --git a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py index 275c70f2..f11d7b42 100644 --- a/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +++ b/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py @@ -13,7 +13,7 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import Dict, Generator, List, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from compressed_tensors.compressors.base import BaseCompressor diff --git a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py index 7b8fea02..a5d6a6b7 100644 --- a/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +++ b/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py @@ -48,7 +48,7 @@ class Marlin24Compressor(BaseCompressor): @staticmethod def validate_quant_compatability( - names_to_scheme: Dict[str, QuantizationScheme] + names_to_scheme: Dict[str, QuantizationScheme], ) -> bool: """ Checks if every quantized module in the model is compatible with Marlin24 diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 7afd2aba..94fc8871 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -73,14 +73,14 @@ def load_pretrained_quantization_parameters( Loads the quantization parameters (scale and zero point) from model_name_or_path to a model that has already been initialized with a quantization config. - NOTE: Will always load inputs/output parameters. - Will conditioanlly load weight parameters, if load_weight_quantization is set to True. + NOTE: Will always load inputs/output parameters. Will conditionally load weight + parameters, if load_weight_quantization is set to True. :param model: model to load pretrained quantization parameters to :param model_name_or_path: Hugging Face stub or local folder containing a quantized model, which is used to load quantization parameters - :param load_weight_quantization: whether or not the weight quantization parameters shoud - be laoded + :param load_weight_quantization: whether or not the weight quantization parameters + should be loaded """ model_path = get_safetensors_folder(model_name_or_path) mapping = get_quantization_parameter_to_path_mapping(model_path) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index b82a4195..645e3938 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -200,7 +200,8 @@ def _process_quantization( q_min, q_max = calculate_range(args, x.device) group_size = args.group_size - # blockwise FP8: quantize per 2D block, supports block_structure for static block quant + # blockwise FP8: quantize per 2D block, supports block_structure for static block + # quantization if args.strategy == QuantizationStrategy.BLOCK: original_shape = x.shape rows, cols = x.shape[-2], x.shape[-1] @@ -209,8 +210,8 @@ def _process_quantization( # Ensure exact division (tensor dimensions must be divisible by block size) if rows % block_height != 0: raise ValueError( - f"Tensor height {rows} is not divisible by block_height {block_height}. " - f"Block quantization requires exact division." + f"Tensor height {rows} is not divisible by block_height {block_height}." + f" Block quantization requires exact division."
) if cols % block_width != 0: raise ValueError( diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index b0c32439..c9430e9e 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -17,7 +17,7 @@ import math import warnings from enum import Enum -from typing import List, Optional +from typing import Optional import torch from compressed_tensors.quantization.lifecycle.forward import ( @@ -87,7 +87,6 @@ def initialize_module_for_quantization( _initialize_attn_scales(module) else: - if scheme.input_activations is not None: _initialize_scale_zero_point( module, @@ -183,7 +182,8 @@ def _initialize_scale_zero_point( num_groups = math.ceil(weight_shape[1] / quantization_args.group_size) expected_shape = (weight_shape[0], max(num_groups, 1)) elif quantization_args.strategy == QuantizationStrategy.BLOCK: - # For block quantization, scale shape should match number of blocks - only for weights + # For block quantization, scale shape should match number of blocks - only + # for weights if quantization_args.block_structure is None: raise ValueError( "Block quantization requires block_structure to be specified" @@ -196,9 +196,10 @@ def _initialize_scale_zero_point( # Warn if dimensions don't divide evenly if rows % block_height != 0 or cols % block_width != 0: warnings.warn( - f"Block quantization: tensor shape {weight_shape} does not divide evenly " - f"by block structure {quantization_args.block_structure}. " - f"Some blocks will be incomplete which may affect quantization quality.", + f"Block quantization: tensor shape {weight_shape} does not divide " + f"evenly by block structure {quantization_args.block_structure}. " + f"Some blocks will be incomplete which may affect quantization " + "quality.", UserWarning, ) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 53dbf88e..8f478ef9 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -217,16 +217,18 @@ def validate_block_structure(cls, value) -> Optional[List[int]]: return [int(x) for x in value.split("x")] except Exception: raise ValueError( - f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]." + f"Invalid block_structure '{value}'. Must be a list of ints " + "[rows, cols]." ) if isinstance(value, (list, tuple)): if len(value) != 2 or not all(isinstance(v, int) for v in value): raise ValueError( - f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]." + f"Invalid block_structure '{value}'. Must be a list of ints " + "[rows, cols]." ) return list(value) raise ValueError( - f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]." + f"Invalid block_structure '{value}'. Must be a list of ints [rows, cols]." ) @field_validator("strategy", mode="before") @@ -307,7 +309,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": ) if strategy not in supported_strategies: raise ValueError( - f"One of {supported_strategies} must be used for dynamic quantization" + f"One of {supported_strategies} must be used for dynamic quant."
) if ( @@ -322,7 +324,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": observer != "memoryless" ): # avoid annoying users with old configs warnings.warn( - "No observer is used for dynamic quantization, setting to None" + "No observer is used for dynamic quant., setting to None" ) observer = None else: diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 1124b55f..b1a17ed0 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -14,7 +14,7 @@ import warnings from copy import deepcopy -from typing import Any, Dict, List, Optional +from typing import List, Optional from compressed_tensors.quantization.quant_args import ( DynamicType, @@ -72,9 +72,10 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": ): warnings.warn( "Using GROUP strategy for both weights and input_activations " - f"with different group sizes ({weights.group_size} vs {inputs.group_size}) " - "may complicate fused kernel implementations. Consider using " - "TENSOR_GROUP strategy for both or matching group sizes.", + f"with different group sizes ({weights.group_size} vs " + f"{inputs.group_size}) may complicate fused kernel implementations. " + "Consider using TENSOR_GROUP strategy for both or matching group" + " sizes.", UserWarning, stacklevel=2, ) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 42a6e19e..989c9724 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -29,7 +29,6 @@ from compressed_tensors.utils import deprecated from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module -from tqdm import tqdm __all__ = [ diff --git a/src/compressed_tensors/registry/registry.py b/src/compressed_tensors/registry/registry.py index 39815a0b..3ad98d85 100644 --- a/src/compressed_tensors/registry/registry.py +++ b/src/compressed_tensors/registry/registry.py @@ -55,7 +55,7 @@ def standardize_lookup_name(name: str) -> str: def standardize_alias_name( - name: Union[None, str, List[str]] + name: Union[None, str, List[str]], ) -> Union[None, str, List[str]]: if name is None: return None diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 02ebd89b..a481497e 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import Optional, Union +from typing import Optional import torch from compressed_tensors.transform import TransformArgs, TransformScheme @@ -26,7 +26,7 @@ from compressed_tensors.utils import get_execution_device, get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict from torch import Tensor, device, dtype -from torch.nn import Linear, Module, Parameter +from torch.nn import Module, Parameter @TransformFactory.register("hadamard") diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 8b829451..4035d7a1 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -24,7 +24,7 @@ from compressed_tensors.utils import get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict from torch import Tensor, device, dtype -from torch.nn import Linear, Module, Parameter +from torch.nn import Module, Parameter @TransformFactory.register("random-matrix") diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py index f353f8a2..92072857 100644 --- a/src/compressed_tensors/transform/utils/matrix.py +++ b/src/compressed_tensors/transform/utils/matrix.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Tuple +from typing import Optional import torch from compressed_tensors.transform import TransformLocation diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 21ce4a0b..a7aafdab 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -110,7 +110,7 @@ def match_modules_set( Yields modules grouped with the same order and size as `targets`. Values are returned in order of `model.named_modules()` - For example, the following targets would yield module belonging to the following layers: + E.g. 
the following targets would yield module belonging to the following layers: ```python3 match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == ( ( diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index cf3707fd..55157b4a 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -282,7 +282,6 @@ def disable_hf_hook(module: torch.nn.Module): hooks = {} def collect_hooks(module): - nonlocal hooks if hasattr(module, "_hf_hook"): hooks[module] = module._hf_hook remove_hook_from_module(module) diff --git a/src/compressed_tensors/utils/safetensors_load.py b/src/compressed_tensors/utils/safetensors_load.py index 13a8eb48..cb2b913b 100644 --- a/src/compressed_tensors/utils/safetensors_load.py +++ b/src/compressed_tensors/utils/safetensors_load.py @@ -18,7 +18,6 @@ import struct from typing import Dict, Iterable, Optional, Tuple, Union -from safetensors import safe_open from torch import Tensor from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file diff --git a/tests/test_examples/test_bitmask_compression_ipynb.py b/tests/test_examples/test_bitmask_compression_ipynb.py index 12f196f6..a813dce8 100644 --- a/tests/test_examples/test_bitmask_compression_ipynb.py +++ b/tests/test_examples/test_bitmask_compression_ipynb.py @@ -16,7 +16,7 @@ nbformat = pytest.importorskip("nbformat") -from nbconvert.preprocessors import ExecutePreprocessor +from nbconvert.preprocessors import ExecutePreprocessor # noqa: E402 @pytest.mark.skip( diff --git a/tests/test_quantization/lifecycle/test_initialize.py b/tests/test_quantization/lifecycle/test_initialize.py index faf56ba6..e613e399 100644 --- a/tests/test_quantization/lifecycle/test_initialize.py +++ b/tests/test_quantization/lifecycle/test_initialize.py @@ -24,7 +24,6 @@ QuantizationScheme, QuantizationStatus, QuantizationStrategy, - QuantizationType, ) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, diff --git a/tests/test_transform/test_transform_config.py b/tests/test_transform/test_transform_config.py index 52167cfd..dc725743 100644 --- a/tests/test_transform/test_transform_config.py +++ b/tests/test_transform/test_transform_config.py @@ -66,6 +66,6 @@ def test_multiple_groups(): type="hadamard", apply=[linear_args_2], ) - config = TransformConfig( + _ = TransformConfig( config_groups={"transform_0": scheme_1, "transform_1": scheme_2} ) diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 705676b9..d168f725 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -70,43 +70,35 @@ class TestMatchName: def test_exact_match(self): """Test exact string matching""" - assert _match_name("layer1", "layer1") == True - assert _match_name("layer1", "layer2") == False - assert ( - _match_name( - "transformer.layers.0.self_attn.q_proj", - "transformer.layers.0.self_attn.q_proj", - ) - == True + assert _match_name("layer1", "layer1") + assert not _match_name("layer1", "layer2") + assert _match_name( + "transformer.layers.0.self_attn.q_proj", + "transformer.layers.0.self_attn.q_proj", ) def test_regex_match(self): """Test regex matching with "re:" prefix""" - assert _match_name("layer1", "re:layer.*") == True - assert _match_name("layer1", "re:^layer1$") == True - assert _match_name("layer1", "re:layer2") == False - assert ( - _match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj") == True - ) - assert ( - _match_name( - 
"transformer.layers.0.self_attn.q_proj", - "re:transformer\\.layers\\.\\d+\\.self_attn\\..*_proj$", - ) - == True + assert _match_name("layer1", "re:layer.*") + assert _match_name("layer1", "re:^layer1$") + assert not _match_name("layer1", "re:layer2") + assert _match_name("transformer.layers.0.self_attn.q_proj", "re:.*q_proj") + assert _match_name( + "transformer.layers.0.self_attn.q_proj", + "re:transformer\\.layers\\.\\d+\\.self_attn\\..*_proj$", ) def test_empty_strings(self): """Test edge cases with empty strings""" - assert _match_name("", "") == True - assert _match_name("layer1", "") == False - assert _match_name("", "layer1") == False + assert _match_name("", "") + assert not _match_name("layer1", "") + assert not _match_name("", "layer1") def test_regex_special_characters(self): """Test regex with special characters""" - assert _match_name("layer.1", "re:layer\\.1") == True - assert _match_name("layer.1", "re:layer.1") == True # . matches any char - assert _match_name("layer_1", "re:layer_1") == True + assert _match_name("layer.1", "re:layer\\.1") + assert _match_name("layer.1", "re:layer.1") # . matches any char + assert _match_name("layer_1", "re:layer_1") class TestMatchClass: @@ -115,32 +107,32 @@ class TestMatchClass: def test_direct_class_match(self): """Test matching direct class names""" linear = nn.Linear(10, 20) - assert _match_class(linear, "Linear") == True - assert _match_class(linear, "Conv2d") == False + assert _match_class(linear, "Linear") + assert not _match_class(linear, "Conv2d") norm = nn.LayerNorm(10) - assert _match_class(norm, "LayerNorm") == True - assert _match_class(norm, "BatchNorm1d") == False + assert _match_class(norm, "LayerNorm") + assert not _match_class(norm, "BatchNorm1d") def test_parent_class_match(self): """Test matching parent class names""" linear = nn.Linear(10, 20) - assert _match_class(linear, "Module") == True + assert _match_class(linear, "Module") conv = nn.Conv2d(3, 16, 3) - assert _match_class(conv, "Module") == True - assert _match_class(conv, "_ConvNd") == True + assert _match_class(conv, "Module") + assert _match_class(conv, "_ConvNd") def test_non_torch_module(self): """Test with non-torch modules""" regular_object = object() - assert _match_class(regular_object, "object") == False # not a torch.nn.Module + assert not _match_class(regular_object, "object") # not a torch.nn.Module def test_custom_module(self): """Test with custom module classes""" model = DummyModel() - assert _match_class(model, "DummyModel") == True - assert _match_class(model, "Module") == True + assert _match_class(model, "DummyModel") + assert _match_class(model, "Module") class TestIsMatch: @@ -149,27 +141,27 @@ class TestIsMatch: def test_name_match(self): """Test matching by name""" linear = nn.Linear(10, 20) - assert is_match("layer1", linear, "layer1") == True - assert is_match("layer1", linear, "layer2") == False + assert is_match("layer1", linear, "layer1") + assert not is_match("layer1", linear, "layer2") def test_class_match(self): """Test matching by class""" linear = nn.Linear(10, 20) - assert is_match("layer1", linear, "Linear") == True - assert is_match("layer1", linear, "Conv2d") == False + assert is_match("layer1", linear, "Linear") + assert not is_match("layer1", linear, "Conv2d") def test_combined_match(self): """Test that either name or class match works""" linear = nn.Linear(10, 20) - assert is_match("layer1", linear, "layer1") == True # name match - assert is_match("layer1", linear, "Linear") == True # class match - assert 
is_match("layer1", linear, "layer2") == False # no match + assert is_match("layer1", linear, "layer1") # name match + assert is_match("layer1", linear, "Linear") # class match + assert not is_match("layer1", linear, "layer2") # no match def test_regex_in_name_match(self): """Test regex matching in name""" linear = nn.Linear(10, 20) - assert is_match("layer1", linear, "re:layer.*") == True - assert is_match("layer1", linear, "re:conv.*") == False + assert is_match("layer1", linear, "re:layer.*") + assert not is_match("layer1", linear, "re:conv.*") def test_internal_module_match(self): """Test not matching internal modules""" @@ -178,7 +170,7 @@ class InternalLinear(InternalModule, nn.Linear): pass linear = InternalLinear(10, 20) - assert is_match("layer1", linear, "re:layer.*") == False + assert not is_match("layer1", linear, "re:layer.*") class TestMatchNamedModules: @@ -364,7 +356,7 @@ def test_module_set_ordering(self): for module_set in matches: # Check that modules are returned in target order (v, q, k) v_proj, q_proj, k_proj = module_set - # We can"t easily check the exact modules, but we can check they"re all Linear + # We can't easily check the exact modules, but can check they're all Linear assert all(isinstance(m, nn.Linear) for m in [v_proj, q_proj, k_proj]) def test_incomplete_set_error(self): diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index 1fce49b3..aed0186b 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -441,8 +441,8 @@ def test_delete_offload_module(exec_device): register_offload_module(model.linear, "child", child) delete_offload_module(model, "child") delete_offload_module(model.linear, "child") - assert not child in model.children() - assert not child in model.linear.children() + assert child not in model.children() + assert child not in model.linear.children() # with offloading model = ExampleModel() @@ -452,8 +452,8 @@ def test_delete_offload_module(exec_device): register_offload_module(model.linear, "child", child) delete_offload_module(model, "child") delete_offload_module(model.linear, "child") - assert not child in model.children() - assert not child in model.linear.children() + assert child not in model.children() + assert child not in model.linear.children() @requires_gpu diff --git a/utils/copyright.py b/utils/copyright.py index 7cc68dd0..433bccc3 100644 --- a/utils/copyright.py +++ b/utils/copyright.py @@ -142,6 +142,9 @@ def _get_files(patterns: List[str]) -> List[str]: def _dont_copyright(file_path: str) -> bool: + if file_path.endswith("compressed_tensors/version.py"): + return True + with open(file_path, "r") as file: content = file.read() @@ -343,4 +346,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()