@@ -703,9 +703,12 @@ def decompress(self, model_path: str, model: Module):
with override_quantization_status(
self.quantization_config, QuantizationStatus.FROZEN
):
names_to_scheme = apply_quantization_config(
model, self.quantization_config
)
apply_quantization_config(model, self.quantization_config)
names_to_scheme: Dict[str, QuantizationScheme] = {
name: getattr(module, "quantization_scheme")
for name, module in model.named_modules()
if getattr(module, "quantization_scheme", None) is not None
}
# Load activation scales/zp or any other quantization parameters
# Conditionally load the weight quantization parameters if we have a
# dense compressor or if a sparsity compressor has already been applied
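Note: with this change, a caller that still needs the old name-to-scheme mapping can rebuild it from the attribute that apply_quantization_config attaches to each matched module. A minimal sketch, illustrative only and not part of this diff; the helper name collect_names_to_scheme is hypothetical:

from typing import Dict

from torch.nn import Module

from compressed_tensors.quantization import QuantizationScheme


def collect_names_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
    # mirrors the mapping that apply_quantization_config used to return:
    # every submodule that was matched against a config group carries a
    # quantization_scheme attribute after the config is applied
    return {
        name: module.quantization_scheme
        for name, module in model.named_modules()
        if getattr(module, "quantization_scheme", None) is not None
    }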
src/compressed_tensors/quantization/lifecycle/apply.py (33 changes: 14 additions & 19 deletions)
@@ -115,7 +115,7 @@ def load_pretrained_quantization_parameters(

def apply_quantization_config(
model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
) -> Dict[str, QuantizationScheme]:
):
"""
Initializes the model for quantization in-place based on the given config.
Optionally converts quantizable modules to compressed_linear modules
@@ -125,26 +125,22 @@ def apply_quantization_config(
:param run_compressed: Whether the model will be run in compressed mode or
decompressed fully on load
"""
# Workaround for when HF Quantizer passes None, see PR #180
if config is None:
return dict()
from compressed_tensors.linear.compressed_linear import CompressedLinear

# remove reference to the original `config`
# argument. This function can mutate it, and we'd
# like to keep the original `config` as it is.
config = deepcopy(config)
if config is None: # see PR #180
return dict()

# preprocess to support kv cache scheme
config = process_quantization_config(config)

# build mapping of targets to schemes for easier matching
# use ordered dict to preserve target ordering in config
target_to_scheme = OrderedDict()
config = process_quantization_config(config)
names_to_scheme = dict()
for scheme in config.config_groups.values():
for target in scheme.targets:
target_to_scheme[target] = scheme

if run_compressed:
from compressed_tensors.linear.compressed_linear import CompressedLinear

# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in match_named_modules(
model, target_to_scheme, config.ignore, warn_on_fail=True
@@ -153,7 +149,12 @@
# quant scheme to the matching layers
matched_targets = match_targets(name, submodule, target_to_scheme)
scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
if run_compressed:
# target matched - add layer and scheme to target list
submodule.quantization_scheme = scheme

# replace with run compressed if applicable
# FUTURE: move this to model compressor
if isinstance(submodule, torch.nn.Linear) and run_compressed:
format = config.format
if format != CompressionFormat.dense.value:
if isinstance(submodule, torch.nn.Linear):
@@ -165,14 +166,8 @@
)
replace_module(model, name, compressed_linear)

# target matched - add layer and scheme to target list
submodule.quantization_scheme = scheme

names_to_scheme[name] = submodule.quantization_scheme

# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)
return names_to_scheme


def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
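For context, a minimal usage sketch of the new calling convention, illustrative only and not part of this diff (the toy model and config values are assumptions): apply_quantization_config is now called purely for its in-place side effects, and the scheme for any targeted layer is read back off the module itself.

import torch

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

# toy model and config; real configs typically come from a serialized checkpoint
model = torch.nn.Sequential(torch.nn.Linear(64, 64))
config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    }
)

apply_quantization_config(model, config)  # applied in place; no return value is relied upon

# the scheme for each matched layer now lives on the module as an attribute
print(model[0].quantization_scheme)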