diff --git a/examples/quantize_and_pack_int4.ipynb b/examples/quantize_and_pack_int4.ipynb index 8cd58f2f2..e4d654685 100644 --- a/examples/quantize_and_pack_int4.ipynb +++ b/examples/quantize_and_pack_int4.ipynb @@ -144,7 +144,7 @@ "outputs": [], "source": [ "quantization_config_dict = {\n", - "\t\"quant_method\": \"sparseml\",\n", + "\t\"quant_method\": \"compressed-tensors\",\n", "\t\"format\": \"pack-quantized\",\n", "\t\"global_compression_ratio\": None,\n", "\t\"config_groups\": {\n", diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 8896c060d..d8fc42078 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -17,7 +17,6 @@ import operator import os import re -from contextlib import contextmanager from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union @@ -50,6 +49,7 @@ get_offloaded_device, get_safetensors_folder, has_offloaded_params, + patch_attr, register_offload_parameter, update_parameter_data, ) @@ -200,9 +200,11 @@ def from_pretrained_model( sparsity_config=sparsity_config, quantization_config=quantization_config, transform_config=transform_config, - compression_formats=[quantization_format] - if isinstance(quantization_format, str) - else quantization_format, + compression_formats=( + [quantization_format] + if isinstance(quantization_format, str) + else quantization_format + ), ) @staticmethod @@ -594,8 +596,10 @@ def decompress(self, model_path: str, model: Module): # that the dtypes of the weights are not unintentionally updated. # The status is restored after quantization params are loaded. - with override_quantization_status( - self.quantization_config, QuantizationStatus.FROZEN + with patch_attr( + self.quantization_config, + "quantization_status", + QuantizationStatus.FROZEN, ): apply_quantization_config(model, self.quantization_config) names_to_scheme: Set[QuantizationScheme] = { @@ -787,23 +791,3 @@ def new_dtype_byte_size(dtype): raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0]) return bit_size // 8 - - -@contextmanager -def override_quantization_status( - config: QuantizationConfig, status: QuantizationStatus -): - """ - Within this context, the quantization status will be set to the - supplied status. After the context exits, the original status - will be restored. - - :param config: the quantization config to override - :param status: the status to temporarily set - """ - original_status = config.quantization_status - config.quantization_status = status - try: - yield - finally: - config.quantization_status = original_status diff --git a/src/compressed_tensors/quantization/__init__.py b/src/compressed_tensors/quantization/__init__.py index 9fde69a35..04ccedf53 100644 --- a/src/compressed_tensors/quantization/__init__.py +++ b/src/compressed_tensors/quantization/__init__.py @@ -17,5 +17,6 @@ from .quant_args import * from .quant_config import * +from .quant_metadata import * from .quant_scheme import * from .lifecycle import * diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index faa48df20..89ac7a887 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -21,9 +21,6 @@ import torch from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization.lifecycle.compressed import ( - compress_quantized_weights, -) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) @@ -35,7 +32,6 @@ from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import ( KV_CACHE_TARGETS, - infer_quantization_status, is_kv_cache_quant_scheme, ) from compressed_tensors.utils.helpers import deprecated, replace_module @@ -49,7 +45,6 @@ __all__ = [ "load_pretrained_quantization_parameters", "apply_quantization_config", - "apply_quantization_status", "find_name_or_class_matches", ] @@ -154,20 +149,27 @@ def apply_quantization_config( # replace with run compressed if applicable # FUTURE: move this to model compressor - if isinstance(submodule, torch.nn.Linear) and run_compressed: - format = config.format - if format != CompressionFormat.dense.value: - if isinstance(submodule, torch.nn.Linear): - # TODO: expand to more module types - compressed_linear = CompressedLinear.from_linear( - submodule, - quantization_scheme=scheme, - quantization_format=format, - ) - replace_module(model, name, compressed_linear) - - # apply current quantization status across all targeted layers - apply_quantization_status(model, config.quantization_status) + if ( + run_compressed + and isinstance(submodule, torch.nn.Linear) + and config.format != CompressionFormat.dense.value + ): + # TODO: expand to more module types + compressed_linear = CompressedLinear.from_linear( + submodule, + quantization_scheme=scheme, + quantization_format=config.format, + ) + replace_module(model, name, compressed_linear) + + else: + initialize_module_for_quantization( + submodule, + force_zero_point=config.quantization_status + != QuantizationStatus.COMPRESSED, + ) + + submodule.quantization_status = config.quantization_status def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: @@ -206,29 +208,6 @@ def process_kv_cache_config( return config -def apply_quantization_status(model: Module, status: QuantizationStatus): - """ - Applies in place the quantization lifecycle up to the given status - - :param model: model to apply quantization to - :param status: status to update the module to - """ - - current_status = infer_quantization_status(model) - - if status >= QuantizationStatus.INITIALIZED > current_status: - force_zero_point_init = status != QuantizationStatus.COMPRESSED - - model.apply( - lambda module: initialize_module_for_quantization( - module, force_zero_point=force_zero_point_init - ) - ) - - if current_status < status >= QuantizationStatus.COMPRESSED > current_status: - model.apply(compress_quantized_weights) - - @deprecated( message="This function is deprecated and will be removed in a future release." "Please use `match_targets` from `compressed_tensors.utils.match` instead." @@ -254,14 +233,6 @@ def find_name_or_class_matches( return match_targets(name, module, targets) -def _infer_status(model: Module) -> Optional[QuantizationStatus]: - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def _load_quant_args_from_mapping( base_name: str, module_name: str, module: Module, mapping: Dict ): diff --git a/src/compressed_tensors/quantization/lifecycle/compressed.py b/src/compressed_tensors/quantization/lifecycle/compressed.py index 00f707920..ee717e399 100644 --- a/src/compressed_tensors/quantization/lifecycle/compressed.py +++ b/src/compressed_tensors/quantization/lifecycle/compressed.py @@ -42,7 +42,8 @@ def compress_quantized_weights(module: Module): # no quantization scheme or weights not quantized, nothing to do return - if scheme is QuantizationStatus.COMPRESSED: + status = getattr(module, "quantization_status", None) + if status is QuantizationStatus.COMPRESSED: # module is already compressed, nothing to do return diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5350b4a2c..9f852c74f 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -16,21 +16,22 @@ import logging import math import warnings -from enum import Enum from typing import Optional import torch -from compressed_tensors.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, -) -from compressed_tensors.quantization.quant_args import ( +from compressed_tensors.quantization import ( FP8_E4M3_DATA, ActivationOrdering, + KVCacheScaleType, QuantizationArgs, + QuantizationMetadata, + QuantizationScheme, + QuantizationStatus, QuantizationStrategy, ) -from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.lifecycle.forward import ( + wrap_module_forward_quantized, +) from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme from compressed_tensors.utils import ( disable_hf_hook, @@ -43,28 +44,23 @@ __all__ = [ "initialize_module_for_quantization", "is_attention_module", - "KVCacheScaleType", ] _LOGGER = logging.getLogger(__name__) -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" - - def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, ): """ - attaches appropriate scales, zero points, and observers to a layer - given its target quantization scheme + Attaches appropriate scales, zero points, and observers to a layer + given its target quantization scheme. - apply to full model with `model.apply(initialize_module_for_quantization)` + Previously initialized scales and zero points will be removed from + module if they no longer apply to the scheme :param module: module to set for calibration :param scheme: scheme to use for quantization. if None is provided, @@ -79,6 +75,8 @@ def initialize_module_for_quantization( # no scheme passed and layer not targeted for quantization - skip return + QuantizationMetadata.clear_all_qparams(module) + if is_attention_module(module): # quantized actions based on calltime status _initialize_attn_scales(module) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 42df3a337..4478a2ae5 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -113,8 +113,8 @@ class QuantizationConfig(BaseModel): :param config_groups: dict of QuantizationSchemes specifying the quantization settings for each quantized layer. A group could also be a reference to a predefined scheme name, mapped to a list of its target layers/classes - :param quant_method: a constant used to differentiate sparseML quantization from - other quantization configs + :param quant_method: a constant used to differentiate compressed-tensors + quantization from other quantization configs :param format: specifies how the quantized model is stored on disk :quantization_status: specifies the current status of all quantized layers. It is assumed all layers are in the same state. @@ -185,7 +185,8 @@ def from_pretrained( ignore[layer_type] = [] ignore[layer_type].append(name) else: - quantization_status = submodule.quantization_status + if hasattr(submodule, "quantization_status"): + quantization_status = submodule.quantization_status scheme = submodule.quantization_scheme quantization_type_names.add(layer_type) diff --git a/src/compressed_tensors/quantization/quant_metadata.py b/src/compressed_tensors/quantization/quant_metadata.py new file mode 100644 index 000000000..e7567eabe --- /dev/null +++ b/src/compressed_tensors/quantization/quant_metadata.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + +from compressed_tensors.utils import delete_offload_parameter +from torch.nn import Module + + +__all__ = ["QuantizationMetadata", "KVCacheScaleType"] + + +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + +class QuantizationMetadata: + """ + Container class for metadata related to quantization + """ + + @staticmethod + def all_qparam_names(): + """ + All quantization parameter names that might be registered + onto a module during lifecycle (excluding serialized parameters) + """ + return [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [ + f"{base_name}_{suffix}" + for base_name in ("input", "weight", "output") + for suffix in ( + "global_scale", + "scale", + "zero_point", + "g_idx", + ) + ] + + @classmethod + def clear_all_qparams(cls, module: Module): + """ + Remove all parameters related to quantization that might have + been registered onto a module previously in lifecycle (excluding + serialized parameters) + + :param module: Module to clear + """ + for key in cls.all_qparam_names(): + if hasattr(module, key): + delete_offload_parameter(module, key) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 73d545193..1b6937d47 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -32,7 +32,6 @@ __all__ = [ - "infer_quantization_status", "is_module_quantized", "is_model_quantized", "module_type", @@ -234,21 +233,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: return q_min, q_max -def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa - """ - Checks the quantization status of a model. Assumes all modules in the model have - the same status, so only the first quantized model is checked. - - :param model: model to check quantization status for - :return: quantization status if the model is quantized, otherwise None - """ - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def is_module_quantized(module: Module) -> bool: """ Check if a module is quantized, based on the existence of a non-empty quantization diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 44adc7ef2..bdaa40c05 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -74,9 +74,6 @@ def infer_compressor_from_model_config( return compressor -# TODO: There is already the same function in -# SparseML, should be moved to a shared location -# in the future def fix_fsdp_module_name(name: str) -> str: """ Remove FSDP wrapper prefixes from a module name diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 11e2a2a1c..b96e83d04 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -136,12 +136,12 @@ def match_targets( if isinstance(module, InternalModule): return [] - # The order of the output `matches` list matters, the are arranged from most + # The order of the output `matches` list matters, they are arranged from most # specific to least specific, and this order will be used when merging configs. # The entries are sorted in the following order: # 1. matches on exact strings # 2. matches on regex patterns - # 3. matches on module names + # 3. matches on module names (e.g. "Linear") targets = sorted(targets, key=lambda x: ("re:" in x, x)) matched_targets = [] diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index 115cf3f5a..5a9a1762f 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -118,16 +118,14 @@ def __init__(self, weights, weight_scale=None, weight_zero_point=None): self.linear = nn.Linear(in_features, out_features, bias=False) # Set the weights of the linear layer - self.linear.weight = nn.Parameter(weights, requires_grad=False) + self.linear.weight = nn.Parameter(weights.detach().clone()) # Attach weight_scale and weight_zero_point as parameters if weight_scale is not None: - self.linear.weight_scale = nn.Parameter( - torch.tensor(weight_scale), requires_grad=False - ) + self.linear.weight_scale = nn.Parameter(weight_scale.detach().clone()) if weight_zero_point is not None: self.linear.weight_zero_point = nn.Parameter( - torch.tensor(weight_zero_point), requires_grad=False + weight_zero_point.detach().clone() ) def forward(self, x): @@ -388,9 +386,7 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir): ) def test_compress_model_meta(model_stub, q_format, s_config): # Load model on CPU to get expected compressed state_dict - cpu_model = AutoModelForCausalLM.from_pretrained( - model_stub, torch_dtype=torch.float32 - ) + cpu_model = AutoModelForCausalLM.from_pretrained(model_stub) reference_compressor = ModelCompressor.from_pretrained_model( cpu_model, s_config, [q_format] ) @@ -400,7 +396,6 @@ def test_compress_model_meta(model_stub, q_format, s_config): # Load model on meta device meta_model = AutoModelForCausalLM.from_pretrained( model_stub, - torch_dtype=torch.float32, low_cpu_mem_usage=True, ) for module in meta_model.modules(): @@ -511,8 +506,12 @@ def test_decompress_model(model_stub, comp_stub): # equivalent to decompressing from disk assert decompressed.keys() == true_decompressed.keys() for key in decompressed.keys(): - assert decompressed[key].dtype == true_decompressed[key].dtype - assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}" + assert ( + decompressed[key].dtype == true_decompressed[key].dtype + ), f"{key} dtypes not equal" + assert torch.all( + decompressed[key] == true_decompressed[key] + ), f"{key} values not equal" def remove_empty_weight_zero_points(state_dict): diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index d5fd6c2cd..ae8908202 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -13,7 +13,6 @@ # limitations under the License. import re -from collections import defaultdict from typing import Optional from unittest.mock import MagicMock @@ -22,13 +21,15 @@ from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_METHOD, + QuantizationArgs, QuantizationConfig, + QuantizationScheme, QuantizationStatus, + QuantizationStrategy, + QuantizationType, ) -from compressed_tensors.quantization.lifecycle import ( - apply_quantization_config, - apply_quantization_status, -) +from compressed_tensors.quantization.lifecycle import apply_quantization_config +from compressed_tensors.utils import is_match, match_named_modules from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM @@ -105,59 +106,28 @@ def test_target_prioritization(mock_frozen): def test_apply_quantization_config_tinyllama(): - quant_config = get_sample_tinyllama_quant_config(status="calibration") + quant_config = get_sample_tinyllama_quant_config( + status=QuantizationStatus.INITIALIZED + ) model = get_tinyllama_model() # check that model is not already quantized for module in model.modules(): _test_layer_quantization_status(module, inputs=False, weights=False) - count_layer_names = ("Linear", "Embeddidng", "LlamaRotaryEmbedding") - count_layer_num = defaultdict(int) - - for name, module in model.named_modules(): - if name in quant_config.ignore: - continue - module_type = module.__class__.__name__ - if module_type in count_layer_names: - count_layer_num[module_type] += 1 - - assert len(count_layer_num) > 0, f"None of {count_layer_names} found in model" - assert all(value > 0 for value in count_layer_num.values()) - # apply quant config to model apply_quantization_config(model, quant_config) # check for correct application of quant config - for name, module in model.named_modules(): - if name in quant_config.ignore: - continue - module_type = module.__class__.__name__ - if module_type in count_layer_names: - count_layer_num[module_type] -= 1 - _inputs = module_type == "Linear" - _weights = not module_type == "LlamaRotaryEmbedding" - _test_layer_quantization_status(module, inputs=_inputs, weights=_weights) - - assert all( - value == 0 for value in count_layer_num.values() - ), "Not all values are zero" - - # test quantization compression - # sample forward pass to fill scales, zps - model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int)) - apply_quantization_status(model, QuantizationStatus.COMPRESSED) - for name, module in model.named_modules(): - if name in quant_config.ignore: - continue - module_type = module.__class__.__name__ - if module_type == "Linear": + for quant_scheme in quant_config.config_groups.values(): + for name, module in match_named_modules( + model, quant_scheme.targets, quant_config.ignore + ): _test_layer_quantization_status( module, - inputs=True, - weights=True, - expected_status=QuantizationStatus.COMPRESSED, - expected_dtype=torch.int8, + inputs=quant_scheme.input_activations is not None, + weights=quant_scheme.weights is not None, + expected_status=QuantizationStatus.INITIALIZED, ) @@ -218,7 +188,9 @@ def get_tinyllama_model(): ) -def get_sample_tinyllama_quant_config(status: str = "frozen"): +def get_sample_tinyllama_quant_config( + status: QuantizationStatus = QuantizationStatus.FROZEN, +): config_dict = { "quant_method": "compressed-tensors", "format": "fakequant", @@ -264,13 +236,13 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"): [("Linear", "re:.*foobarbaz"), True], ], ) -def test_apply_quantization_status(caplog, target, should_raise_warning): +def test_apply_quantization_config(caplog, target, should_raise_warning): import logging # load a dense, unquantized tiny llama model model = get_tinyllama_model() quantization_config_dict = { - "quant_method": "sparseml", + "quant_method": "compressed-tensors", "format": "pack-quantized", "global_compression_ratio": None, "config_groups": { @@ -297,3 +269,100 @@ def test_apply_quantization_status(caplog, target, should_raise_warning): assert len(caplog.text) > 0 else: assert len(caplog.text) == 0 + + +def test_multi_apply_quantization_config(): + """ + Ensure that multiple quantization configs are applied correctly + If quantization config was previously applied to a module, + those changes should be reset for newly applied quantization config + """ + model = get_tinyllama_model() + + # FP8 applied to self_attn + qconfig1 = QuantizationConfig( + config_groups={ + "group_0": QuantizationScheme( + targets=[ + r"re:.*self_attn\.(k|q|o|v)_proj$", + ], + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + ), + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + ), + ) + }, + ignore=["lm_head"], + ) + # W4A16_ASYM applied to mlp and self_attn.o_proj to validate overwriting + qconfig2 = QuantizationConfig( + config_groups={ + "group_0": QuantizationScheme( + targets=[ + r"re:.*mlp\.(down|gate|up)_proj$", + r"re:.*self_attn\.o_proj$", + ], + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + strategy=QuantizationStrategy.GROUP, + group_size=128, + symmetric=False, + dynamic=False, + ), + ) + }, + ignore=["lm_head"], + ) + + apply_quantization_config(model, qconfig1) + apply_quantization_config(model, qconfig2) + for name, module in model.named_modules(): + if is_match( + name, module, qconfig2.config_groups["group_0"].targets, qconfig2.ignore + ): + # assert W4A16_ASYM parameters are present with correct shape + # and FP8 parameters have been removed + assert not hasattr(module, "input_scale") + assert not hasattr(module, "input_zero_point") + weight_scale = getattr(module, "weight_scale", None) + assert ( + weight_scale is not None + and weight_scale.shape[:-1] == module.weight.shape[:-1] + and weight_scale.shape[-1] == module.weight.shape[-1] / 128 + ) + weight_zero_point = getattr(module, "weight_zero_point", None) + assert ( + weight_zero_point is not None + and weight_zero_point.shape[:-1] == module.weight.shape[:-1] + and weight_zero_point.shape[-1] == module.weight.shape[-1] / 128 + ) + + elif is_match( + name, module, qconfig1.config_groups["group_0"].targets, qconfig1.ignore + ): + # assert FP8 scheme parameters are present with correct shape + input_scale = getattr(module, "input_scale", None) + assert input_scale is not None and input_scale.shape == torch.Size([1]) + input_zero_point = getattr(module, "input_zero_point", None) + assert ( + input_zero_point is not None + and input_zero_point.shape == torch.Size([1]) + ) + weight_scale = getattr(module, "weight_scale", None) + assert weight_scale is not None and weight_scale.shape == torch.Size([1]) + weight_zero_point = getattr(module, "weight_zero_point", None) + assert ( + weight_zero_point is not None + and weight_zero_point.shape == torch.Size([1]) + )