
Commit 0679683

Authored by Fynn Schmitt-Ulms
Add quality check to CI and fix existing errors (#408)
* Add `python-style` checks to pr ci
* Ignore auto-generated version.py file in copyright check
* Run `make style`
* Fix line lengths
* Remove unused imports
* Fix explicit comparison to True/False
* Fix misc flake8 errors
* Skip auto-generated `version.py` file when running isort
* Move quality check to separate workflow file and fix error msg
* Add 'release/*' branch triggers to ci quality and tests

---------

Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent d4354b0 commit 0679683
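For context on the style fixes listed above, the snippet below is a hypothetical before/after for the "explicit comparison to True/False" item (the flake8 E712 pattern); it is illustrative only and not taken from this commit's diff.

    flag = True

    # before (flagged by flake8 E712):
    # if flag == True:
    #     do_work()

    # after:
    if flag:
        pass  # do_work()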

File tree

30 files changed: +153 additions, -129 deletions


.github/workflows/quality-check.yaml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+name: Quality Checks
+on:
+  push:
+    branches:
+      - main
+      - 'release/*'
+  pull_request:
+    branches:
+      - main
+      - 'release/*'
+
+jobs:
+  quality-check:
+    runs-on: ubuntu-24.04
+    steps:
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+          fetch-tags: true
+      - name: Set Env
+        run: |
+          pip3 install --upgrade pip && pip3 install --upgrade setuptools
+      - name: "⚙️ Install dependencies"
+        run: pip3 install .[dev]
+      - name: "🧹 Running quality checks"
+        run: make quality

.github/workflows/test-check.yaml

Lines changed: 3 additions & 0 deletions
@@ -4,9 +4,11 @@ on:
   push:
     branches:
       - main
+      - 'release/*'
   pull_request:
     branches:
       - main
+      - 'release/*'
 
 jobs:
   python-tests:
@@ -26,3 +28,4 @@ jobs:
         run: pip3 install .[dev,accelerate]
       - name: "🔬 Running tests"
         run: make test
+

setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ ensure_newline_before_comments = True
 force_grid_wrap = 0
 include_trailing_comma = True
 sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
+skip = src/compressed_tensors/version.py
 
 line_length = 88
 lines_after_imports = 2

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 15 additions & 15 deletions
@@ -42,7 +42,6 @@
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
-from compressed_tensors.quantization.utils import is_module_quantized
 from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
@@ -309,7 +308,7 @@ def __init__(
         if quantization_config is not None:
             # If a list of compression_format is not provided, we resolve the
             # relevant quantization formats using the config groups from the config
-            # and if those are not defined, we fall-back to the global quantization format
+            # and if those are not defined, we fall-back to the global quantization fmt
             if not self.compression_formats:
                 self.compression_formats = self._fetch_unique_quantization_formats()
 
@@ -661,11 +660,12 @@ def decompress(self, model_path: str, model: Module):
         :param model_path: path to compressed weights
         :param model: pytorch model to load decompressed weights into
 
-        Note: decompress makes use of both _replace_sparsity_weights and _replace_weights
-        The variations in these methods are a result of the subtle variations between the sparsity
-        and quantization compressors. Specifically, quantization compressors return not just the
-        decompressed weight, but the quantization parameters (e.g scales, zero_point) whereas sparsity
-        compressors only return the decompressed weight.
+        Note: decompress makes use of both _replace_sparsity_weights and
+        _replace_weights. The variations in these methods are a result of the subtle
+        variations between the sparsity and quantization compressors. Specifically,
+        quantization compressors return not just the decompressed weight, but the
+        quantization parameters (e.g scales, zero_point) whereas sparsity compressors
+        only return the decompressed weight.
 
         """
         model_path = get_safetensors_folder(model_path)
@@ -707,13 +707,13 @@ def decompress(self, model_path: str, model: Module):
                 model, self.quantization_config
             )
             # Load activation scales/zp or any other quantization parameters
-            # Conditionally load the weight quantization parameters if we have a dense compressor
-            # Or if a sparsity compressor has already been applied
+            # Conditionally load the weight quantization parameters if we have a
+            # dense compressor or if a sparsity compressor has already been applied
             load_pretrained_quantization_parameters(
                 model,
                 model_path,
-                # TODO: all weight quantization params will be moved to the compressor in a follow-up
-                # including initialization
+                # TODO: all weight quantization params will be moved to the
+                # compressor in a follow-up including initialization
                 load_weight_quantization=(
                     sparse_decompressed
                     or isinstance(quant_compressor, DenseCompressor)
@@ -805,7 +805,6 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
         :param model: The model whose weights are to be updated.
         """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
@@ -841,9 +840,10 @@ def _replace_weights(self, dense_weight_generator, model: Module):
             for param_name, param_data in data.items():
                 if hasattr(module, param_name):
                     # If compressed, will have an incorrect dtype for transformers >4.49
-                    # TODO: we can also just skip initialization of scales/zp if in decompression in init
-                    # to be consistent with loading which happens later as well
-                    # however, update_data does a good shape check - should be moved to the compressor
+                    # TODO: we can also just skip initialization of scales/zp if in
+                    # decompression in init to be consistent with loading which happens
+                    # later as well however, update_data does a good shape check -
+                    # should be moved to the compressor
                     if param_name == "weight":
                         delattr(module, param_name)
                         requires_grad = param_data.dtype in (
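As a rough illustration of the asymmetry described in the decompress docstring above, the sketch below (hypothetical names and shapes, not code from this repository) shows why separate replacement helpers are needed: a quantization compressor yields the weight together with its quantization parameters, while a sparsity compressor yields only the weight.

    import torch


    def quant_decompress_example(name: str):
        # hypothetical quantization decompression: weight plus its quantization params
        yield name, {
            "weight": torch.zeros(4, 4),
            "weight_scale": torch.ones(4, 1),
            "weight_zero_point": torch.zeros(4, 1, dtype=torch.int8),
        }


    def sparse_decompress_example(name: str):
        # hypothetical sparsity decompression: the dense weight only
        yield name, {"weight": torch.zeros(4, 4)}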

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,6 @@
     get_nested_weight_mappings,
     merge_names,
 )
-from compressed_tensors.utils.safetensors_load import match_param_name
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -107,7 +106,8 @@ def compress(
                 compressed_dict[name] = value.to(compression_device)
                 continue
 
-            # compress values on meta if loading from meta otherwise on cpu (memory movement too expensive)
+            # compress values on meta if loading from meta otherwise on cpu (memory
+            # movement too expensive)
             module_path = prefix[:-1] if prefix.endswith(".") else prefix
             quant_args = names_to_scheme[module_path].weights
             compressed_values = self.compress_weight(
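A minimal sketch of the device choice described in the rewrapped comment above, assuming the usual convention that "meta" tensors carry no storage; this is an illustration, not the compressor's actual code:

    import torch


    def pick_compression_device(value: torch.Tensor) -> torch.device:
        # compress on meta if the value was loaded on meta (no data to move),
        # otherwise stage on CPU since cross-device movement is too expensive
        return value.device if value.device.type == "meta" else torch.device("cpu")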

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 4 additions & 5 deletions
@@ -15,7 +15,6 @@
 
 from typing import Dict, Optional, Tuple
 
-import numpy
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.quantized_compressors.base import (
@@ -92,7 +91,6 @@ def compress_weight(
         zero_point: Optional[torch.Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
-
         quantized_weight = quantize(
             x=weight,
             scale=scale,
@@ -112,7 +110,6 @@ def decompress_weight(
         compressed_data: Dict[str, Tensor],
         quantization_args: Optional[QuantizationArgs] = None,
     ) -> torch.Tensor:
-
         weight = compressed_data["weight_packed"]
         scale = compressed_data["weight_scale"]
         global_scale = compressed_data["weight_global_scale"]
@@ -175,14 +172,16 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
 )
 
+
 # reference: : https://github.com/vllm-project/vllm/pull/16362
 def unpack_fp4_from_uint8(
     a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
 ) -> torch.Tensor:
     """
     Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values
-    (i.e. first four bits correspond to one fp4 value, last four corresond to a consecutive
-    fp4 value). The bits represent an index, which are mapped to an fp4 value.
+    (i.e. first four bits correspond to one fp4 value, last four correspond to a
+    consecutive fp4 value). The bits represent an index, which are mapped to an fp4
+    value.
 
     :param a: tensor to unpack
     :param m: original dim 0 size of the unpacked tensor
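The docstring above describes the packing layout only; as a rough sketch of the nibble decoding it implies (nibble order and sign handling are assumptions here, and this is not the library's implementation), two 4-bit indices per byte select entries from the fp4 lookup table shown in the diff:

    import torch

    FP4_MAGNITUDES = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
    )


    def unpack_two_fp4_magnitudes(byte: int) -> tuple[float, float]:
        # low nibble first is an assumption; sign-bit handling is omitted for brevity
        low = byte & 0x0F
        high = (byte >> 4) & 0x0F
        return FP4_MAGNITUDES[low & 0x7].item(), FP4_MAGNITUDES[high & 0x7].item()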

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 4 additions & 3 deletions
@@ -14,7 +14,6 @@
 import math
 from typing import Dict, Literal, Optional, Tuple, Union
 
-import numpy as np
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.quantized_compressors.base import (
@@ -135,7 +134,8 @@ def compress_weight(
         compressed_dict["weight_shape"] = weight_shape
         compressed_dict["weight_packed"] = packed_weight
 
-        # We typically don't compress zp; apart from when using the packed_compressor and when storing group/channel zp
+        # We typically don't compress zp; apart from when using the packed_compressor
+        # and when storing group/channel zp
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
@@ -166,7 +166,8 @@ def decompress_weight(
         num_bits = quantization_args.num_bits
         unpacked = unpack_from_int32(weight, num_bits, original_shape)
 
-        # NOTE: this will fail decompression as we don't currently handle packed zp on decompression
+        # NOTE: this will fail decompression as we don't currently handle packed zp on
+        # decompression
         if not quantization_args.symmetric and quantization_args.strategy in [
             QuantizationStrategy.GROUP.value,
             QuantizationStrategy.CHANNEL.value,
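For intuition about the unpack_from_int32 call above, here is a hedged sketch of the packing direction (pack_to_int32_sketch is a hypothetical helper, not the library's function, and the exact bit layout used by the compressor may differ):

    import torch


    def pack_to_int32_sketch(values: torch.Tensor, num_bits: int) -> torch.Tensor:
        # values: 1-D tensor of non-negative ints, each smaller than 2 ** num_bits
        per_word = 32 // num_bits
        pad = (-values.numel()) % per_word
        padded = torch.nn.functional.pad(values, (0, pad)).to(torch.int64)
        packed = torch.zeros(padded.numel() // per_word, dtype=torch.int64)
        for i in range(per_word):
            # place the i-th value of each group into bits [i*num_bits, (i+1)*num_bits)
            packed |= padded[i::per_word] << (i * num_bits)
        return packed  # kept as int64 here to sidestep int32 overflow in the sketch

With num_bits=4, eight values share one 32-bit word; unpack_from_int32 reverses this kind of layout using the stored num_bits and original shape.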

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Dict, Generator, List, Tuple, Union
+from typing import Dict, List, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor

src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ class Marlin24Compressor(BaseCompressor):
 
     @staticmethod
     def validate_quant_compatability(
-        names_to_scheme: Dict[str, QuantizationScheme]
+        names_to_scheme: Dict[str, QuantizationScheme],
     ) -> bool:
         """
         Checks if every quantized module in the model is compatible with Marlin24

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 6 additions & 5 deletions
@@ -71,14 +71,14 @@ def load_pretrained_quantization_parameters(
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
     a model that has already been initialized with a quantization config.
 
-    NOTE: Will always load inputs/output parameters.
-    Will conditioanlly load weight parameters, if load_weight_quantization is set to True.
+    NOTE: Will always load inputs/output parameters. Will conditioanlly load weight
+    parameters, if load_weight_quantization is set to True.
 
     :param model: model to load pretrained quantization parameters to
     :param model_name_or_path: Hugging Face stub or local folder containing a quantized
         model, which is used to load quantization parameters
-    :param load_weight_quantization: whether or not the weight quantization parameters shoud
-        be laoded
+    :param load_weight_quantization: whether or not the weight quantization parameters
+        should be loaded
     """
     model_path = get_safetensors_folder(model_name_or_path)
     mapping = get_quantization_parameter_to_path_mapping(model_path)
@@ -261,7 +261,8 @@ def find_name_or_class_matches(
     """
     if check_contains:
         raise NotImplementedError(
-            "This function is deprecated, and the check_contains=True option has been removed."
+            "This function is deprecated, and the check_contains=True option has been"
+            " removed."
         )
 
     return match_targets(name, module, targets)
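A hedged usage sketch of the helper whose docstring was rewrapped above; the call signature follows the parameters documented in the diff, the import path is assumed from the grouped import shown in model_compressor.py, and the model and path are placeholders:

    # import path assumed; the function lives in quantization/lifecycle/apply.py
    from compressed_tensors.quantization import load_pretrained_quantization_parameters

    # model: a torch.nn.Module already initialized with a quantization config
    # "/path/to/compressed-model" is a hypothetical local folder or Hugging Face stub
    # load_pretrained_quantization_parameters(
    #     model,
    #     "/path/to/compressed-model",
    #     load_weight_quantization=True,  # also load weight scales / zero points
    # )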
