
Commit dfd069b

[Multi-Modifier] Scoped apply quantization config (#432)
* squashed/rebased
* cleanup
* remove TODO
* more cleanup
* cleanup
* formatting
* formatting
* resolve redundant merge code
* style fixes
* cleanup / test fixes
* test fixes
* formatting/touchups
* stylefix
* stylefixes
* stylefixes
* remaining test fixes
* revert extraneous test change
* remove test running code
* remove infer_quantization_status
* lifecycle updates for overwriting config
* remove compress_quantized_weight, test fixes, remove sparseml references
* drop frozen scale_dtype post-merge
* formatting
* clear previously initialized qparams
* remove apply_quantization_status
* stylefix
* add ALL_QPARAM_KEYS var
* multi-apply quantization config test
* multi-apply test cleanup
* ALL_QPARAM_NAMES
* stylefixes
* exclude sparsity param names
* QuantizationMetadata class
* stylefix
* llm-compressor test fix

---------

Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 0e5df88 commit dfd069b

13 files changed: +245 additions, -178 deletions

examples/quantize_and_pack_int4.ipynb

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@
     "outputs": [],
     "source": [
     "quantization_config_dict = {\n",
-    "\t\"quant_method\": \"sparseml\",\n",
+    "\t\"quant_method\": \"compressed-tensors\",\n",
     "\t\"format\": \"pack-quantized\",\n",
     "\t\"global_compression_ratio\": None,\n",
     "\t\"config_groups\": {\n",

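For readers skimming the notebook diff, the cell now declares the compressed-tensors quant_method. Below is a plain-Python sketch of the dict being built; only the keys visible above come from the notebook, and the config_groups entry is a hypothetical placeholder:

quantization_config_dict = {
    "quant_method": "compressed-tensors",  # was "sparseml" before this commit
    "format": "pack-quantized",
    "global_compression_ratio": None,
    "config_groups": {
        # placeholder group; the notebook defines its own targets and schemes
        "group_0": {"targets": ["Linear"], "weights": {"num_bits": 4, "type": "int"}},
    },
}
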
src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 10 additions & 26 deletions
@@ -17,7 +17,6 @@
 import operator
 import os
 import re
-from contextlib import contextmanager
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union

@@ -50,6 +49,7 @@
     get_offloaded_device,
     get_safetensors_folder,
     has_offloaded_params,
+    patch_attr,
     register_offload_parameter,
     update_parameter_data,
 )
@@ -200,9 +200,11 @@ def from_pretrained_model(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
-            compression_formats=[quantization_format]
-            if isinstance(quantization_format, str)
-            else quantization_format,
+            compression_formats=(
+                [quantization_format]
+                if isinstance(quantization_format, str)
+                else quantization_format
+            ),
         )

     @staticmethod
@@ -594,8 +596,10 @@ def decompress(self, model_path: str, model: Module):
         # that the dtypes of the weights are not unintentionally updated.
         # The status is restored after quantization params are loaded.

-        with override_quantization_status(
-            self.quantization_config, QuantizationStatus.FROZEN
+        with patch_attr(
+            self.quantization_config,
+            "quantization_status",
+            QuantizationStatus.FROZEN,
         ):
             apply_quantization_config(model, self.quantization_config)
             names_to_scheme: Set[QuantizationScheme] = {
@@ -787,23 +791,3 @@ def new_dtype_byte_size(dtype):
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
     return bit_size // 8
-
-
-@contextmanager
-def override_quantization_status(
-    config: QuantizationConfig, status: QuantizationStatus
-):
-    """
-    Within this context, the quantization status will be set to the
-    supplied status. After the context exits, the original status
-    will be restored.
-
-    :param config: the quantization config to override
-    :param status: the status to temporarily set
-    """
-    original_status = config.quantization_status
-    config.quantization_status = status
-    try:
-        yield
-    finally:
-        config.quantization_status = original_status

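The bespoke override_quantization_status context manager is gone; the decompress path now uses the generic patch_attr helper from compressed_tensors.utils to temporarily freeze the config. A minimal, self-contained usage sketch (the toy model and config below are illustrative, not lifted from the file):

import torch
from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationStatus,
    apply_quantization_config,
)
from compressed_tensors.utils import patch_attr

model = torch.nn.Sequential(torch.nn.Linear(8, 8))
config = QuantizationConfig(config_groups={"W8A8": ["Linear"]})

with patch_attr(config, "quantization_status", QuantizationStatus.FROZEN):
    # while the config reports FROZEN, applying it does not unintentionally
    # update weight dtypes (per the comment in decompress())
    apply_quantization_config(model, config)

# on exit, config.quantization_status is restored to its previous value
print(config.quantization_status)
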
src/compressed_tensors/quantization/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,5 +17,6 @@

 from .quant_args import *
 from .quant_config import *
+from .quant_metadata import *
 from .quant_scheme import *
 from .lifecycle import *

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 21 additions & 50 deletions
@@ -21,9 +21,6 @@

 import torch
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.lifecycle.compressed import (
-    compress_quantized_weights,
-)
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
@@ -35,7 +32,6 @@
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     KV_CACHE_TARGETS,
-    infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
 from compressed_tensors.utils.helpers import deprecated, replace_module
@@ -49,7 +45,6 @@
 __all__ = [
     "load_pretrained_quantization_parameters",
     "apply_quantization_config",
-    "apply_quantization_status",
     "find_name_or_class_matches",
 ]

@@ -154,20 +149,27 @@ def apply_quantization_config(

         # replace with run compressed if applicable
         # FUTURE: move this to model compressor
-        if isinstance(submodule, torch.nn.Linear) and run_compressed:
-            format = config.format
-            if format != CompressionFormat.dense.value:
-                if isinstance(submodule, torch.nn.Linear):
-                    # TODO: expand to more module types
-                    compressed_linear = CompressedLinear.from_linear(
-                        submodule,
-                        quantization_scheme=scheme,
-                        quantization_format=format,
-                    )
-                    replace_module(model, name, compressed_linear)
-
-    # apply current quantization status across all targeted layers
-    apply_quantization_status(model, config.quantization_status)
+        if (
+            run_compressed
+            and isinstance(submodule, torch.nn.Linear)
+            and config.format != CompressionFormat.dense.value
+        ):
+            # TODO: expand to more module types
+            compressed_linear = CompressedLinear.from_linear(
+                submodule,
+                quantization_scheme=scheme,
+                quantization_format=config.format,
+            )
+            replace_module(model, name, compressed_linear)
+
+        else:
+            initialize_module_for_quantization(
+                submodule,
+                force_zero_point=config.quantization_status
+                != QuantizationStatus.COMPRESSED,
+            )
+
+        submodule.quantization_status = config.quantization_status


 def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
@@ -206,29 +208,6 @@ def process_kv_cache_config(
     return config


-def apply_quantization_status(model: Module, status: QuantizationStatus):
-    """
-    Applies in place the quantization lifecycle up to the given status
-
-    :param model: model to apply quantization to
-    :param status: status to update the module to
-    """
-
-    current_status = infer_quantization_status(model)
-
-    if status >= QuantizationStatus.INITIALIZED > current_status:
-        force_zero_point_init = status != QuantizationStatus.COMPRESSED
-
-        model.apply(
-            lambda module: initialize_module_for_quantization(
-                module, force_zero_point=force_zero_point_init
-            )
-        )
-
-    if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
-        model.apply(compress_quantized_weights)
-
-
 @deprecated(
     message="This function is deprecated and will be removed in a future release."
     "Please use `match_targets` from `compressed_tensors.utils.match` instead."
@@ -254,14 +233,6 @@ def find_name_or_class_matches(
     return match_targets(name, module, targets)


-def _infer_status(model: Module) -> Optional[QuantizationStatus]:
-    for module in model.modules():
-        status = getattr(module, "quantization_status", None)
-        if status is not None:
-            return status
-    return None
-
-
 def _load_quant_args_from_mapping(
     base_name: str, module_name: str, module: Module, mapping: Dict
 ):

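With apply_quantization_status removed, apply_quantization_config now initializes each targeted module (or swaps in CompressedLinear when run_compressed is set) and stamps config.quantization_status onto it directly, so a config can be applied more than once — the multi-modifier scenario this commit targets. A rough sketch using assumed preset scheme names, not taken from the commit's test suite:

import torch
from compressed_tensors.quantization import QuantizationConfig, apply_quantization_config

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16))

# first pass: target Linear layers with the W8A8 preset scheme
apply_quantization_config(model, QuantizationConfig(config_groups={"W8A8": ["Linear"]}))

# second pass: re-apply with a different preset; initialize_module_for_quantization
# clears the previously registered qparams before registering the new ones
apply_quantization_config(model, QuantizationConfig(config_groups={"W4A16": ["Linear"]}))
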
src/compressed_tensors/quantization/lifecycle/compressed.py

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ def compress_quantized_weights(module: Module):
         # no quantization scheme or weights not quantized, nothing to do
         return

-    if scheme is QuantizationStatus.COMPRESSED:
+    status = getattr(module, "quantization_status", None)
+    if status is QuantizationStatus.COMPRESSED:
         # module is already compressed, nothing to do
         return


src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 14 additions & 16 deletions
@@ -16,21 +16,22 @@
 import logging
 import math
 import warnings
-from enum import Enum
 from typing import Optional

 import torch
-from compressed_tensors.quantization.lifecycle.forward import (
-    wrap_module_forward_quantized,
-)
-from compressed_tensors.quantization.quant_args import (
+from compressed_tensors.quantization import (
     FP8_E4M3_DATA,
     ActivationOrdering,
+    KVCacheScaleType,
     QuantizationArgs,
+    QuantizationMetadata,
+    QuantizationScheme,
+    QuantizationStatus,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.quant_config import QuantizationStatus
-from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.lifecycle.forward import (
+    wrap_module_forward_quantized,
+)
 from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
 from compressed_tensors.utils import (
     disable_hf_hook,
@@ -43,28 +44,23 @@
 __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
-    "KVCacheScaleType",
 ]


 _LOGGER = logging.getLogger(__name__)


-class KVCacheScaleType(Enum):
-    KEY = "k_scale"
-    VALUE = "v_scale"
-
-
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
     force_zero_point: bool = True,
 ):
     """
-    attaches appropriate scales, zero points, and observers to a layer
-    given its target quantization scheme
+    Attaches appropriate scales, zero points, and observers to a layer
+    given its target quantization scheme.

-    apply to full model with `model.apply(initialize_module_for_quantization)`
+    Previously initialized scales and zero points will be removed from
+    module if they no longer apply to the scheme

     :param module: module to set for calibration
     :param scheme: scheme to use for quantization. if None is provided,
@@ -79,6 +75,8 @@ def initialize_module_for_quantization(
         # no scheme passed and layer not targeted for quantization - skip
         return

+    QuantizationMetadata.clear_all_qparams(module)
+
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)

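At the module level, initialize_module_for_quantization now calls QuantizationMetadata.clear_all_qparams before registering new parameters, so re-initializing with a different scheme does not leave stale scales or zero points behind. A small sketch with hand-built schemes (illustrative values, not from the diff):

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    initialize_module_for_quantization,
)

linear = torch.nn.Linear(64, 64)

# registers weight_scale / weight_zero_point for an 8-bit weight scheme
initialize_module_for_quantization(
    linear,
    scheme=QuantizationScheme(targets=["Linear"], weights=QuantizationArgs(num_bits=8)),
)

# re-initializing clears the old qparams first (QuantizationMetadata.clear_all_qparams),
# then registers the parameters for the new 4-bit grouped scheme
initialize_module_for_quantization(
    linear,
    scheme=QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(num_bits=4, group_size=128, strategy="group"),
    ),
)
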
src/compressed_tensors/quantization/quant_config.py

Lines changed: 4 additions & 3 deletions
@@ -113,8 +113,8 @@ class QuantizationConfig(BaseModel):
     :param config_groups: dict of QuantizationSchemes specifying the quantization
         settings for each quantized layer. A group could also be a reference to
         a predefined scheme name, mapped to a list of its target layers/classes
-    :param quant_method: a constant used to differentiate sparseML quantization from
-        other quantization configs
+    :param quant_method: a constant used to differentiate compressed-tensors
+        quantization from other quantization configs
     :param format: specifies how the quantized model is stored on disk
     :quantization_status: specifies the current status of all quantized layers. It is
         assumed all layers are in the same state.
@@ -185,7 +185,8 @@ def from_pretrained(
                     ignore[layer_type] = []
                 ignore[layer_type].append(name)
             else:
-                quantization_status = submodule.quantization_status
+                if hasattr(submodule, "quantization_status"):
+                    quantization_status = submodule.quantization_status
                 scheme = submodule.quantization_scheme
                 quantization_type_names.add(layer_type)


src/compressed_tensors/quantization/quant_metadata.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+from compressed_tensors.utils import delete_offload_parameter
+from torch.nn import Module
+
+
+__all__ = ["QuantizationMetadata", "KVCacheScaleType"]
+
+
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
+class QuantizationMetadata:
+    """
+    Container class for metadata related to quantization
+    """
+
+    @staticmethod
+    def all_qparam_names():
+        """
+        All quantization parameter names that might be registered
+        onto a module during lifecycle (excluding serialized parameters)
+        """
+        return [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [
+            f"{base_name}_{suffix}"
+            for base_name in ("input", "weight", "output")
+            for suffix in (
+                "global_scale",
+                "scale",
+                "zero_point",
+                "g_idx",
+            )
+        ]
+
+    @classmethod
+    def clear_all_qparams(cls, module: Module):
+        """
+        Remove all parameters related to quantization that might have
+        been registered onto a module previously in lifecycle (excluding
+        serialized parameters)
+
+        :param module: Module to clear
+        """
+        for key in cls.all_qparam_names():
+            if hasattr(module, key):
+                delete_offload_parameter(module, key)

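A quick usage sketch of the new class against a plain torch module (output abbreviated):

import torch
from compressed_tensors.quantization import QuantizationMetadata

print(QuantizationMetadata.all_qparam_names())
# ['k_scale', 'v_scale', 'input_global_scale', 'input_scale', 'input_zero_point',
#  'input_g_idx', 'weight_global_scale', 'weight_scale', ...]

linear = torch.nn.Linear(16, 16)
# deletes any of the above that were previously registered; a no-op for a fresh module
QuantizationMetadata.clear_all_qparams(linear)
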
src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 0 additions & 16 deletions
@@ -32,7 +32,6 @@


 __all__ = [
-    "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
     "module_type",
@@ -234,21 +233,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
     return q_min, q_max


-def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
-    """
-    Checks the quantization status of a model. Assumes all modules in the model have
-    the same status, so only the first quantized model is checked.
-
-    :param model: model to check quantization status for
-    :return: quantization status if the model is quantized, otherwise None
-    """
-    for module in model.modules():
-        status = getattr(module, "quantization_status", None)
-        if status is not None:
-            return status
-    return None
-
-
 def is_module_quantized(module: Module) -> bool:
     """
     Check if a module is quantized, based on the existence of a non-empty quantization
