
Commit 2ecb124

Simplify apply_quantization_config (#433)
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: 15177be
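
In short, as the diffs below show: apply_quantization_config no longer builds and returns a names_to_scheme mapping. It now only attaches a quantization_scheme attribute to each matched submodule in-place, and callers such as ModelCompressor.decompress recover the mapping themselves by scanning model.named_modules() for that attribute.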

2 files changed: 20 additions & 22 deletions


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 6 additions & 3 deletions
@@ -703,9 +703,12 @@ def decompress(self, model_path: str, model: Module):
         with override_quantization_status(
             self.quantization_config, QuantizationStatus.FROZEN
         ):
-            names_to_scheme = apply_quantization_config(
-                model, self.quantization_config
-            )
+            apply_quantization_config(model, self.quantization_config)
+            names_to_scheme: Set[QuantizationScheme] = {
+                name: getattr(module, "quantization_scheme")
+                for name, module in model.named_modules()
+                if getattr(module, "quantization_scheme", None) is not None
+            }
             # Load activation scales/zp or any other quantization parameters
             # Conditionally load the weight quantization parameters if we have a
             # dense compressor or if a sparsity compressor has already been applied
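
The new decompress body recovers the name-to-scheme mapping from module attributes instead of taking it as a return value. Below is a minimal runnable sketch of that collection pattern, with a toy dataclass standing in for compressed_tensors' QuantizationScheme. (The diff annotates the result as Set[QuantizationScheme] even though the comprehension builds a dict, so the sketch uses Dict[str, QuantizationScheme].)

from dataclasses import dataclass
from typing import Dict, List

import torch


# Toy stand-in for compressed_tensors' QuantizationScheme
@dataclass
class QuantizationScheme:
    targets: List[str]


model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# Pretend apply_quantization_config already matched the Linear layer
# and attached a scheme to it as a plain attribute
model[0].quantization_scheme = QuantizationScheme(targets=["Linear"])

# Same collection pattern as the new decompress() body: walk every
# submodule and keep the ones carrying a quantization_scheme
names_to_scheme: Dict[str, QuantizationScheme] = {
    name: getattr(module, "quantization_scheme")
    for name, module in model.named_modules()
    if getattr(module, "quantization_scheme", None) is not None
}

print(names_to_scheme)  # {'0': QuantizationScheme(targets=['Linear'])}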

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 14 additions & 19 deletions
@@ -115,7 +115,7 @@ def load_pretrained_quantization_parameters(
 
 def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
-) -> Dict[str, QuantizationScheme]:
+):
     """
     Initializes the model for quantization in-place based on the given config.
     Optionally coverts quantizable modules to compressed_linear modules
@@ -125,26 +125,22 @@ def apply_quantization_config(
     :param run_compressed: Whether the model will be run in compressed mode or
         decompressed fully on load
     """
-    # Workaround for when HF Quantizer passes None, see PR #180
-    if config is None:
-        return dict()
+    from compressed_tensors.linear.compressed_linear import CompressedLinear
 
-    # remove reference to the original `config`
-    # argument. This function can mutate it, and we'd
-    # like to keep the original `config` as it is.
     config = deepcopy(config)
+    if config is None:  # see PR #180
+        return dict()
+
+    # preprocess to support kv cache scheme
+    config = process_quantization_config(config)
+
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
-    config = process_quantization_config(config)
-    names_to_scheme = dict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
-    if run_compressed:
-        from compressed_tensors.linear.compressed_linear import CompressedLinear
-
     # mark appropriate layers for quantization by setting their quantization schemes
     for name, submodule in match_named_modules(
         model, target_to_scheme, config.ignore, warn_on_fail=True
@@ -153,7 +149,12 @@
         # quant scheme to the matching layers
         matched_targets = match_targets(name, submodule, target_to_scheme)
         scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
-        if run_compressed:
+        # target matched - add layer and scheme to target list
+        submodule.quantization_scheme = scheme
+
+        # replace with run compressed if applicable
+        # FUTURE: move this to model compressor
+        if isinstance(submodule, torch.nn.Linear) and run_compressed:
             format = config.format
             if format != CompressionFormat.dense.value:
                 if isinstance(submodule, torch.nn.Linear):
@@ -165,14 +166,8 @@
                     )
                     replace_module(model, name, compressed_linear)
 
-        # target matched - add layer and scheme to target list
-        submodule.quantization_scheme = scheme
-
-        names_to_scheme[name] = submodule.quantization_scheme
-
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
-    return names_to_scheme
 
 
 def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
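
Taken together, the simplified loop attaches the scheme to every matched module first, then treats the CompressedLinear swap as a separate optional step gated on run_compressed. Here is a self-contained sketch of that attach-then-maybe-replace flow; the class-name matching and the printed stub are toy stand-ins for compressed_tensors' match_named_modules and CompressedLinear.from_linear:

from collections import OrderedDict
from typing import Optional

import torch


def apply_config_sketch(
    model: torch.nn.Module,
    target_to_scheme: "OrderedDict[str, str]",
    run_compressed: bool = False,
) -> None:
    """Sketch of the simplified flow: attach schemes in-place; return nothing."""
    for name, submodule in model.named_modules():
        # Toy matching on the submodule's class name; the real code uses
        # match_named_modules with target/regex support
        scheme: Optional[str] = target_to_scheme.get(type(submodule).__name__)
        if scheme is None:
            continue

        # target matched - attach the scheme directly to the layer
        submodule.quantization_scheme = scheme

        # replacement is now a separate, optional concern
        if isinstance(submodule, torch.nn.Linear) and run_compressed:
            print(f"would replace {name!r} with CompressedLinear")


model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
apply_config_sketch(model, OrderedDict(Linear="W8A8"), run_compressed=True)
print(model[0].quantization_scheme)  # W8A8

Because the scheme now lives on the module itself, decompress can rebuild names_to_scheme after the fact, which is what makes dropping the return value safe.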
