@@ -264,23 +264,18 @@ def parse_quantization_config(

return quantization_config

def _fetch_unique_quantization_formats(self) -> List[str]:
def _fetch_unique_quantization_formats(self) -> List[Optional[str]]:
"""
Get all unique compression formats present in a model.
:return: list of quantization formats
"""
quantization_formats = []
for _, scheme in self.quantization_config.config_groups.items():
if scheme.format is not None and scheme.format not in quantization_formats:
quantization_formats.append(scheme.format)
quantization_formats = set(
scheme.format for scheme in self.quantization_config.config_groups.values()
)
quantization_formats.add(self.quantization_config.format)

if (
len(quantization_formats) == 0
and self.quantization_config.format
!= CompressionFormat.mixed_precision.value
):
quantization_formats.append(self.quantization_config.format)
return quantization_formats
quantization_formats -= {CompressionFormat.mixed_precision.value, None}
return list(quantization_formats)
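
A minimal sketch of the new deduplication behavior above, using plain strings in place of the real scheme objects and config; the format values below are illustrative, not taken from this PR:

group_formats = ["pack-quantized", "pack-quantized", None]  # per-group scheme.format values
global_format = "mixed-precision"                           # quantization_config.format

formats = set(group_formats)
formats.add(global_format)
formats -= {"mixed-precision", None}  # drop the non-formats, mirroring the set subtraction above
print(sorted(formats))                # ['pack-quantized']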

def __init__(
self,
@@ -314,6 +309,9 @@ def __init__(

self.quantization_compressor = {}
for format in self.compression_formats:
if format is None:
format = CompressionFormat.dense.value

self.quantization_compressor[
format
] = BaseCompressor.load_from_registry(
@@ -703,9 +701,12 @@ def decompress(self, model_path: str, model: Module):
with override_quantization_status(
self.quantization_config, QuantizationStatus.FROZEN
):
names_to_scheme = apply_quantization_config(
model, self.quantization_config
)
apply_quantization_config(model, self.quantization_config)
names_to_scheme: Dict[str, QuantizationScheme] = {
name: getattr(module, "quantization_scheme")
for name, module in model.named_modules()
if getattr(module, "quantization_scheme", None) is not None
}
# Load activation scales/zp or any other quantization parameters
# Conditionally load the weight quantization parameters if we have a
# dense compressor or if a sparsity compressor has already been applied
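
A hedged, self-contained sketch of how decompress now recovers the name-to-scheme mapping after apply_quantization_config has attached quantization_scheme attributes to matching submodules; ToyModel and the string standing in for a QuantizationScheme are assumptions for illustration only:

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8)
        self.norm = torch.nn.LayerNorm(8)

model = ToyModel()
model.q_proj.quantization_scheme = "W4A16"  # stand-in for a real QuantizationScheme

names_to_scheme = {
    name: getattr(module, "quantization_scheme")
    for name, module in model.named_modules()
    if getattr(module, "quantization_scheme", None) is not None
}
print(names_to_scheme)  # {'q_proj': 'W4A16'}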
@@ -811,6 +812,8 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module):

params_device = next(module.parameters()).device
device = "cpu" if has_offloaded_params(module) else params_device
if not hasattr(module, param_name):
breakpoint()
delattr(module, param_name)
requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
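
A hedged sketch of the parameter-swap pattern used in _replace_sparsity_weights: drop the existing parameter, then register a fresh one built from the decompressed data on the chosen device. The Linear module and zero tensor are stand-ins, and the offload handling from the real helper is omitted:

import torch

module = torch.nn.Linear(4, 4)
data = torch.zeros(4, 4, dtype=torch.float16)  # stand-in for decompressed weight data
device = "cpu"                                 # the real helper picks cpu vs. the module's device

delattr(module, "weight")
requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
module.register_parameter("weight", param)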
167 changes: 48 additions & 119 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -13,11 +13,8 @@
# limitations under the License.

import logging
from collections import OrderedDict
from copy import deepcopy
from typing import Dict, Iterable, List, Optional
from typing import OrderedDict as OrderedDictType
from typing import Union
from typing import Dict, Iterable, List, Optional, Union

import torch
from compressed_tensors.config import CompressionFormat
@@ -36,20 +33,22 @@
from compressed_tensors.quantization.utils import (
KV_CACHE_TARGETS,
infer_quantization_status,
is_kv_cache_quant_scheme,
)
from compressed_tensors.utils.helpers import deprecated, replace_module
from compressed_tensors.utils.match import match_named_modules, match_targets
from compressed_tensors.utils.offload import update_parameter_data
from compressed_tensors.utils.safetensors_load import get_safetensors_folder
from safetensors import safe_open
from torch.nn import Module
from transformers import PreTrainedModel


__all__ = [
"load_pretrained_quantization_parameters",
"apply_quantization_config",
"apply_quantization_status",
"attach_scheme",
"attach_config",
"find_name_or_class_matches",
]

@@ -114,8 +113,10 @@ def load_pretrained_quantization_parameters(


def apply_quantization_config(
model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
) -> Dict[str, QuantizationScheme]:
model: PreTrainedModel,
config: Union[QuantizationConfig, None],
run_compressed: bool = False,
):
"""
Initializes the model for quantization in-place based on the given config.
Optionally converts quantizable modules to compressed_linear modules
@@ -125,54 +126,54 @@
:param run_compressed: Whether the model will be run in compressed mode or
decompressed fully on load
"""
# Workaround for when HF Quantizer passes None, see PR #180
if config is None:
return dict()
from compressed_tensors.linear.compressed_linear import CompressedLinear

# remove reference to the original `config`
# argument. This function can mutate it, and we'd
# like to keep the original `config` as it is.
config = deepcopy(config)
# build mapping of targets to schemes for easier matching
# use ordered dict to preserve target ordering in config
target_to_scheme = OrderedDict()
config = process_quantization_config(config)
names_to_scheme = dict()
for scheme in config.config_groups.values():
for target in scheme.targets:
target_to_scheme[target] = scheme
if config is None: # see PR #180
return dict()

if run_compressed:
from compressed_tensors.linear.compressed_linear import CompressedLinear
# preprocess to support kv cache scheme
config = process_quantization_config(config)

# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in match_named_modules(
model, target_to_scheme, config.ignore, warn_on_fail=True
):
# mark modules to be quantized by adding
# quant scheme to the matching layers
matched_targets = match_targets(name, submodule, target_to_scheme)
scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
if run_compressed:
format = config.format
if format != CompressionFormat.dense.value:
if isinstance(submodule, torch.nn.Linear):
# TODO: expand to more module types
compressed_linear = CompressedLinear.from_linear(
submodule,
quantization_scheme=scheme,
quantization_format=format,
)
replace_module(model, name, compressed_linear)

# target matched - add layer and scheme to target list
submodule.quantization_scheme = scheme

names_to_scheme[name] = submodule.quantization_scheme
for scheme in config.config_groups.values():
for name, submodule in match_named_modules(
model, scheme.targets, config.ignore or [], warn_on_fail=True
):
# attach scheme to module (with merging)
attach_scheme(submodule, scheme)

# replace with run compressed if applicable
# FUTURE: move this to model compressor
if isinstance(submodule, torch.nn.Linear) and run_compressed:
format = config.format
if format != CompressionFormat.dense.value:
if isinstance(submodule, torch.nn.Linear):
# TODO: expand to more module types
compressed_linear = CompressedLinear.from_linear(
submodule,
quantization_scheme=scheme,
quantization_format=format,
)
replace_module(model, name, compressed_linear)

# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)
return names_to_scheme

# attach config for serialization
attach_config(model, config)


def attach_scheme(module: Module, scheme: QuantizationScheme):
if existing_scheme := getattr(module, "quantization_scheme", None):
scheme = scheme.merge(existing_scheme)
setattr(module, "quantization_scheme", scheme)


def attach_config(model: PreTrainedModel, config: QuantizationConfig):
if existing_config := getattr(model, "quantization_config", None):
config = config.merge(existing_config)
setattr(model, "quantization_config", config)


def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
@@ -268,14 +269,6 @@ def find_name_or_class_matches(
return match_targets(name, module, targets)


def _infer_status(model: Module) -> Optional[QuantizationStatus]:
for module in model.modules():
status = getattr(module, "quantization_status", None)
if status is not None:
return status
return None


def _load_quant_args_from_mapping(
base_name: str, module_name: str, module: Module, mapping: Dict
):
@@ -318,67 +311,3 @@ def _load_quant_args_from_mapping(
state_dict_zp = f.get_tensor(f"{module_name}.{zp_name}")

update_parameter_data(module, state_dict_zp, zp_name)


def _scheme_from_targets(
target_to_scheme: OrderedDictType[str, QuantizationScheme],
targets: List[str],
name: str,
) -> QuantizationScheme:
if len(targets) == 1:
# if `targets` iterable contains a single element
# use it as the key
return target_to_scheme[targets[0]]

# otherwise, we need to merge QuantizationSchemes corresponding
# to multiple targets. This is most likely because `name` module
# is being target both as an ordinary quantization target, as well
# as kv cache quantization target
schemes_to_merge = [target_to_scheme[target] for target in targets]
return _merge_schemes(schemes_to_merge, name)


def _merge_schemes(
schemes_to_merge: List[QuantizationScheme], name: str
) -> QuantizationScheme:
kv_cache_quantization_scheme = [
scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
]
if not kv_cache_quantization_scheme:
# if the schemes_to_merge do not contain any
# kv cache QuantizationScheme
# return the first scheme (the prioritized one,
# since the order of schemes_to_merge matters)
return schemes_to_merge[0]
else:
# fetch the kv cache QuantizationScheme and the highest
# priority non-kv cache QuantizationScheme and merge them
kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
quantization_scheme = [
scheme
for scheme in schemes_to_merge
if not is_kv_cache_quant_scheme(scheme)
][0]
schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
merged_scheme = {}
for scheme in schemes_to_merge:
scheme_dict = {
k: v for k, v in scheme.model_dump().items() if v is not None
}
# when merging multiple schemes, the final target will be
# the `name` argument - hence erase the original targets
del scheme_dict["targets"]
# make sure that schemes do not "clash" with each other
overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
if overlapping_keys:
raise ValueError(
f"The module: {name} is being modified by two clashing "
f"quantization schemes, that jointly try to override "
f"properties: {overlapping_keys}. Fix the quantization config "
"so that it is not ambiguous."
)
merged_scheme.update(scheme_dict)

merged_scheme.update(targets=[name])

return QuantizationScheme(**merged_scheme)