Commit 14a359f

remove compress_quantized_weights, test fixes, remove sparseml references
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent fc2e102 commit 14a359f

5 files changed: +17 -25 lines changed

examples/quantize_and_pack_int4.ipynb

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@
 "outputs": [],
 "source": [
 "quantization_config_dict = {\n",
-"\t\"quant_method\": \"sparseml\",\n",
+"\t\"quant_method\": \"compressed-tensors\",\n",
 "\t\"format\": \"pack-quantized\",\n",
 "\t\"global_compression_ratio\": None,\n",
 "\t\"config_groups\": {\n",

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 4 additions & 10 deletions
@@ -21,9 +21,6 @@
 
 import torch
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.lifecycle.compressed import (
-    compress_quantized_weights,
-)
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
@@ -219,20 +216,17 @@ def apply_quantization_status(module: Module, status: QuantizationStatus):
     # When decompressing, we set the scale_dtype as the model's dtype
     # This is because the normal workflow of using the weight's dtype
     # will be incorrect as the model weight will be compressed
-    # Therfore, use the dtype set by the user using the PretrainedModel
+    # Therefore, use the dtype set by the user using the PretrainedModel
     scale_dtype = None
     if status == QuantizationStatus.FROZEN:
         if hasattr(module, "dtype"):
             scale_dtype = module.dtype
 
-    module.apply(
-        lambda module: initialize_module_for_quantization(
-            module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
-        )
+    initialize_module_for_quantization(
+        module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
     )
 
-    if status >= QuantizationStatus.COMPRESSED:
-        module.apply(compress_quantized_weights)
+    module.quantization_status = status
 
 
 @deprecated(
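Read together, the two hunks above mean apply_quantization_status no longer fans out over submodules or compresses weights itself: it initializes quantization on the module it is given and records the requested status. Below is a sketch of the resulting body, reconstructed from the context and added lines; the force_zero_point_init assignment sits above this hunk and is assumed here.

# Sketch reconstructed from the hunk; not a verbatim copy of the file.
from torch.nn import Module

from compressed_tensors.quantization import QuantizationStatus
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)


def apply_quantization_status(module: Module, status: QuantizationStatus):
    # Assumption: this flag is set a few lines above the hunk; the exact
    # expression is not visible in the diff.
    force_zero_point_init = status != QuantizationStatus.COMPRESSED

    # When decompressing, use the dtype the user set on the PretrainedModel,
    # since the compressed weight's own dtype would be misleading
    scale_dtype = None
    if status == QuantizationStatus.FROZEN:
        if hasattr(module, "dtype"):
            scale_dtype = module.dtype

    initialize_module_for_quantization(
        module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
    )
    module.quantization_status = status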

src/compressed_tensors/quantization/quant_config.py

Lines changed: 2 additions & 2 deletions
@@ -113,8 +113,8 @@ class QuantizationConfig(BaseModel):
     :param config_groups: dict of QuantizationSchemes specifying the quantization
         settings for each quantized layer. A group could also be a reference to
         a predefined scheme name, mapped to a list of its target layers/classes
-    :param quant_method: a constant used to differentiate sparseML quantization from
-        other quantization configs
+    :param quant_method: a constant used to differentiate compressed-tensors
+        quantization from other quantization configs
     :param format: specifies how the quantized model is stored on disk
     :quantization_status: specifies the current status of all quantized layers. It is
         assumed all layers are in the same state.
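Since the hunk header shows QuantizationConfig subclasses pydantic's BaseModel, a dict with these params can be passed straight to the constructor and nested groups are coerced into scheme objects. A minimal sketch follows; the single group below is an illustrative assumption, not a value taken from this commit.

from compressed_tensors.quantization import QuantizationConfig

# Illustrative values only; "group_0" and its weight settings are assumptions
config = QuantizationConfig(
    quant_method="compressed-tensors",
    format="pack-quantized",
    config_groups={
        "group_0": {
            "targets": ["Linear"],
            "weights": {"num_bits": 4, "type": "int", "symmetric": True},
        }
    },
)
print(config.quant_method, config.format)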

src/compressed_tensors/utils/helpers.py

Lines changed: 0 additions & 3 deletions
@@ -71,9 +71,6 @@ def infer_compressor_from_model_config(
     return compressor
 
 
-# TODO: There is already the same function in
-# SparseML, should be moved to a shared location
-# in the future
 def fix_fsdp_module_name(name: str) -> str:
     """
     Remove FSDP wrapper prefixes from a module name
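For context, FSDP wraps submodules under an extra name segment (commonly _fsdp_wrapped_module), so a name like model._fsdp_wrapped_module.layers.0.q_proj must have that segment stripped before it can be matched against quantization targets. A rough standalone sketch of the idea; this is not compressed-tensors' exact implementation, and the wrapper string is an assumption.

# Standalone sketch: strip the FSDP wrapper segment from a dotted module name.
# The wrapper name below is an assumption, not taken from this commit.
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"


def fix_fsdp_module_name(name: str) -> str:
    # Drop the wrapper segment wherever it appears in the dotted path
    return name.replace(f"{FSDP_WRAPPER_NAME}.", "").replace(
        f".{FSDP_WRAPPER_NAME}", ""
    )


assert (
    fix_fsdp_module_name("model._fsdp_wrapped_module.layers.0.q_proj")
    == "model.layers.0.q_proj"
)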

tests/test_quantization/lifecycle/test_apply.py

Lines changed: 10 additions & 9 deletions
@@ -25,10 +25,7 @@
     QuantizationConfig,
     QuantizationStatus,
 )
-from compressed_tensors.quantization.lifecycle import (
-    apply_quantization_config,
-    apply_quantization_status,
-)
+from compressed_tensors.quantization.lifecycle import apply_quantization_config
 from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM
 
@@ -105,7 +102,9 @@ def test_target_prioritization(mock_frozen):
 
 
 def test_apply_quantization_config_tinyllama():
-    quant_config = get_sample_tinyllama_quant_config(status="calibration")
+    quant_config = get_sample_tinyllama_quant_config(
+        status=QuantizationStatus.CALIBRATION
+    )
     model = get_tinyllama_model()
 
     # check that model is not already quantized
@@ -146,7 +145,8 @@ def test_apply_quantization_config_tinyllama():
     # test quantization compression
     # sample forward pass to fill scales, zps
     model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int))
-    apply_quantization_status(model, QuantizationStatus.COMPRESSED)
+    quant_config.quantization_status = QuantizationStatus.COMPRESSED
+    apply_quantization_config(model, quant_config)
     for name, module in model.named_modules():
         if name in quant_config.ignore:
             continue
@@ -157,7 +157,6 @@ def test_apply_quantization_config_tinyllama():
         inputs=True,
         weights=True,
         expected_status=QuantizationStatus.COMPRESSED,
-        expected_dtype=torch.int8,
     )
 
 
@@ -218,7 +217,9 @@ def get_tinyllama_model():
     )
 
 
-def get_sample_tinyllama_quant_config(status: str = "frozen"):
+def get_sample_tinyllama_quant_config(
+    status: QuantizationStatus = QuantizationStatus.FROZEN,
+):
     config_dict = {
         "quant_method": "compressed-tensors",
         "format": "fakequant",
@@ -270,7 +271,7 @@ def test_apply_quantization_status(caplog, target, should_raise_warning):
     # load a dense, unquantized tiny llama model
     model = get_tinyllama_model()
     quantization_config_dict = {
-        "quant_method": "sparseml",
+        "quant_method": "compressed-tensors",
         "format": "pack-quantized",
         "global_compression_ratio": None,
         "config_groups": {
