Commit 72e5f3d

remove apply_quantization_status
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 6ba47e5 commit 72e5f3d

File tree

3 files changed: +19 / -61 lines changed

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 8 additions & 18 deletions

@@ -45,7 +45,6 @@
 __all__ = [
     "load_pretrained_quantization_parameters",
     "apply_quantization_config",
-    "apply_quantization_status",
     "find_name_or_class_matches",
 ]

@@ -163,8 +162,14 @@ def apply_quantization_config(
             )
             replace_module(model, name, compressed_linear)

-        # apply current quantization status to each targeted submodule
-        apply_quantization_status(submodule, config.quantization_status)
+        else:
+            initialize_module_for_quantization(
+                submodule,
+                force_zero_point=config.quantization_status
+                != QuantizationStatus.COMPRESSED,
+            )
+
+        submodule.quantization_status = config.quantization_status


 def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:

@@ -203,21 +208,6 @@ def process_kv_cache_config(
     return config


-def apply_quantization_status(module: Module, status: QuantizationStatus):
-    """
-    Applies in place the quantization lifecycle up to the given status
-
-    :param module: module to apply quantization to
-    :param status: status to update the module to
-    """
-
-    force_zero_point_init = status != QuantizationStatus.COMPRESSED
-
-    initialize_module_for_quantization(module, force_zero_point=force_zero_point_init)
-
-    module.quantization_status = status
-
-
 @deprecated(
     message="This function is deprecated and will be removed in a future release."
     "Please use `match_targets` from `compressed_tensors.utils.match` instead."

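Since this commit removes `apply_quantization_status`, any downstream caller that used it has to reproduce the logic inline. A minimal sketch, assuming `QuantizationStatus` and `initialize_module_for_quantization` stay importable from the packages referenced in this diff; the helper name `set_module_quantization_status` is hypothetical and not part of the library:

import torch.nn as nn

from compressed_tensors.quantization import QuantizationStatus
from compressed_tensors.quantization.lifecycle import initialize_module_for_quantization


def set_module_quantization_status(module: nn.Module, status: QuantizationStatus) -> None:
    # Hypothetical replacement for the removed apply_quantization_status helper,
    # mirroring its behavior (now inlined into apply_quantization_config):
    # a zero point is only forced when the target status is not COMPRESSED.
    initialize_module_for_quantization(
        module,
        force_zero_point=(status != QuantizationStatus.COMPRESSED),
    )
    # the status flag is recorded on the module either way
    module.quantization_status = status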
src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 2 additions & 2 deletions

@@ -75,14 +75,14 @@ def initialize_module_for_quantization(
     :param force_zero_point: whether to force initialization of a zero point for
         symmetric quantization
     """
-    _clear_all_qparams(module)
-
     # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
         # no scheme passed and layer not targeted for quantization - skip
         return

+    _clear_all_qparams(module)
+
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)
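In initialize.py, `_clear_all_qparams` now runs only after the scheme check, so modules not targeted for quantization return early and are left untouched. A minimal sketch of the reordered control flow; `_clear_all_qparams` is replaced by a simplified stand-in here, since only the guard ordering is being illustrated:

import torch.nn as nn


def _clear_all_qparams(module: nn.Module) -> None:
    # simplified stand-in for the library helper of the same name: remove any
    # previously registered quantization parameters from this module
    for name in [
        n for n, _ in module.named_parameters(recurse=False)
        if n.endswith(("_scale", "_zero_point"))
    ]:
        delattr(module, name)


def initialize_module_for_quantization(module: nn.Module, scheme=None) -> None:
    # sketch of the reordered guard; the real function also initializes
    # weight/input/output qparams and handles attention modules
    scheme = scheme or getattr(module, "quantization_scheme", None)
    if scheme is None:
        # untargeted layers now exit before any qparams are cleared
        return
    _clear_all_qparams(module)  # only reached for modules with a scheme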

tests/test_quantization/lifecycle/test_apply.py

Lines changed: 9 additions & 41 deletions

@@ -26,6 +26,7 @@
     QuantizationStatus,
 )
 from compressed_tensors.quantization.lifecycle import apply_quantization_config
+from compressed_tensors.utils import match_named_modules
 from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM

@@ -103,60 +104,27 @@ def test_target_prioritization(mock_frozen):

 def test_apply_quantization_config_tinyllama():
     quant_config = get_sample_tinyllama_quant_config(
-        status=QuantizationStatus.CALIBRATION
+        status=QuantizationStatus.INITIALIZED
     )
     model = get_tinyllama_model()

     # check that model is not already quantized
     for module in model.modules():
         _test_layer_quantization_status(module, inputs=False, weights=False)

-    count_layer_names = ("Linear", "Embeddidng", "LlamaRotaryEmbedding")
-    count_layer_num = defaultdict(int)
-
-    for name, module in model.named_modules():
-        if name in quant_config.ignore:
-            continue
-        module_type = module.__class__.__name__
-        if module_type in count_layer_names:
-            count_layer_num[module_type] += 1
-
-    assert len(count_layer_num) > 0, f"None of {count_layer_names} found in model"
-    assert all(value > 0 for value in count_layer_num.values())
-
     # apply quant config to model
     apply_quantization_config(model, quant_config)

     # check for correct application of quant config
-    for name, module in model.named_modules():
-        if name in quant_config.ignore:
-            continue
-        module_type = module.__class__.__name__
-        if module_type in count_layer_names:
-            count_layer_num[module_type] -= 1
-            _inputs = module_type == "Linear"
-            _weights = not module_type == "LlamaRotaryEmbedding"
-            _test_layer_quantization_status(module, inputs=_inputs, weights=_weights)
-
-    assert all(
-        value == 0 for value in count_layer_num.values()
-    ), "Not all values are zero"
-
-    # test quantization compression
-    # sample forward pass to fill scales, zps
-    model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int))
-    quant_config.quantization_status = QuantizationStatus.COMPRESSED
-    apply_quantization_config(model, quant_config)
-    for name, module in model.named_modules():
-        if name in quant_config.ignore:
-            continue
-        module_type = module.__class__.__name__
-        if module_type == "Linear":
+    for quant_scheme in quant_config.config_groups.values():
+        for name, module in match_named_modules(
+            model, quant_scheme.targets, quant_config.ignore
+        ):
             _test_layer_quantization_status(
                 module,
-                inputs=True,
-                weights=True,
-                expected_status=QuantizationStatus.COMPRESSED,
+                inputs=quant_scheme.input_activations is not None,
+                weights=quant_scheme.weights is not None,
+                expected_status=QuantizationStatus.INITIALIZED,
             )
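The rewritten test derives its expectations from the config groups via `match_named_modules` instead of hard-coded layer-class names. A minimal sketch of that iteration pattern; the helper name `check_initialized` is hypothetical, and the `weight_scale`/`input_scale` attribute names are assumptions about which qparams the library registers on initialized modules:

from compressed_tensors.quantization import QuantizationStatus
from compressed_tensors.utils import match_named_modules


def check_initialized(model, quant_config) -> None:
    # every module matched by a config group's targets (and not ignored) should
    # have been initialized according to that group's scheme
    for quant_scheme in quant_config.config_groups.values():
        for name, module in match_named_modules(
            model, quant_scheme.targets, quant_config.ignore
        ):
            assert module.quantization_status == QuantizationStatus.INITIALIZED, name
            if quant_scheme.weights is not None:
                assert hasattr(module, "weight_scale"), name  # assumed qparam name
            if quant_scheme.input_activations is not None:
                assert hasattr(module, "input_scale"), name  # assumed qparam name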