
Commit fdfe6ce

Move files from compressed_tensors.compressors.utils to compressed_tensors.utils to avoid name mangling (#102)
1 parent 4b790ec commit fdfe6ce

File tree

8 files changed: +41 -70 lines


src/compressed_tensors/compressors/marlin_24.py

Lines changed: 6 additions & 5 deletions
@@ -18,15 +18,16 @@
 import numpy as np
 import torch
 from compressed_tensors.compressors import Compressor
-from compressed_tensors.compressors.utils import (
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization.lifecycle.forward import quantize
+from compressed_tensors.utils import (
     get_permutations_24,
+    is_quantization_param,
+    merge_names,
     sparse_semi_structured_from_dense_cutlass,
     tensor_follows_mask_structure,
 )
-from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
-from compressed_tensors.quantization.lifecycle.forward import quantize
-from compressed_tensors.utils import is_quantization_param, merge_names
 from torch import Tensor
 from tqdm import tqdm
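For downstream code, the practical effect of this hunk is a single import root: the helpers that previously lived under compressed_tensors.compressors.utils now resolve from compressed_tensors.utils. A minimal sketch of the consolidated imports (names taken only from the diff above; the grouping comment is illustrative, not part of the commit):

# Illustrative only: after this change the 2:4-sparsity, permutation, and
# naming helpers share one import root with the other utilities.
from compressed_tensors.utils import (
    get_permutations_24,
    is_quantization_param,
    merge_names,
    sparse_semi_structured_from_dense_cutlass,
    tensor_follows_mask_structure,
)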

src/compressed_tensors/compressors/utils/__init__.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/compressed_tensors/compressors/utils/helpers.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

src/compressed_tensors/utils/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -13,4 +13,7 @@
 # limitations under the License.
 # flake8: noqa
 
+from .helpers import *
+from .permutations_24 import *
 from .safetensors_load import *
+from .semi_structured_conversions import *

src/compressed_tensors/utils/helpers.py

Lines changed: 31 additions & 1 deletion
@@ -14,10 +14,15 @@
 
 from typing import Optional
 
+import torch
 from transformers import AutoConfig
 
 
-__all__ = ["infer_compressor_from_model_config", "fix_fsdp_module_name"]
+__all__ = [
+    "infer_compressor_from_model_config",
+    "fix_fsdp_module_name",
+    "tensor_follows_mask_structure",
+]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
 
@@ -60,3 +65,28 @@ def fix_fsdp_module_name(name: str) -> str:
     return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
         "." + FSDP_WRAPPER_NAME, ""
     )
+
+
+def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
+    """
+    :param tensor: tensor to check
+    :param mask: mask structure to check for, in the format "n:m"
+    :return: True if the tensor follows the mask structure, False otherwise.
+        Note, some weights can incidentally be zero, so we check for
+        at least n zeros in each chunk of size m
+    """
+
+    n, m = tuple(map(int, mask.split(":")))
+    # Reshape the tensor into chunks of size m
+    tensor = tensor.view(-1, m)
+
+    # Count the number of zeros in each chunk
+    zero_counts = (tensor == 0).sum(dim=1)
+
+    # Check that each chunk has at least n zeros; ">=" is needed because
+    # some weights can incidentally be zero
+    if not torch.all(zero_counts >= n).item():
+        raise ValueError()
+
+    return True
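A usage sketch for the relocated helper, assuming only what the diff above shows. One detail worth noting: despite the docstring, a tensor that violates the mask structure raises ValueError rather than returning False, so callers that want a boolean need to catch it.

import torch

from compressed_tensors.utils import tensor_follows_mask_structure

# Toy 2x4 weight in which every chunk of 4 values holds at least 2 zeros,
# i.e. it satisfies the default "2:4" semi-structured sparsity pattern.
weight = torch.tensor(
    [
        [0.0, 1.5, 0.0, -0.3],
        [2.0, 0.0, 0.0, 0.7],
    ]
)
assert tensor_follows_mask_structure(weight, mask="2:4")

# A fully dense chunk does not yield False; the helper raises instead.
dense = torch.ones(1, 4)
try:
    tensor_follows_mask_structure(dense, mask="2:4")
except ValueError:
    print("dense weight does not follow the 2:4 mask structure")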
File renamed without changes.
File renamed without changes.

tests/test_compressors/test_marlin_24.py

Lines changed: 1 addition & 2 deletions
@@ -21,7 +21,6 @@
     Marlin24Compressor,
     map_modules_to_quant_args,
 )
-from compressed_tensors.compressors.utils import mask_creator
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import (
     QuantizationArgs,
@@ -32,7 +31,7 @@
     apply_quantization_config,
     apply_quantization_status,
 )
-from compressed_tensors.utils import merge_names
+from compressed_tensors.utils import mask_creator, merge_names
 from torch.nn.modules import Linear, Sequential