diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py new file mode 100644 index 000000000..13c9827ad --- /dev/null +++ b/src/compressed_tensors/modeling/attention.py @@ -0,0 +1,158 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Optional + +import torch +from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, + forward_quantize, +) +from compressed_tensors.quantization.lifecycle.initialize import ( + _initialize_scale_zero_point, +) +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from torch.utils.hooks import RemovableHandle +from transformers import AttentionInterface, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + +__all__ = ["IMPL_ATTR", "QuantizedAttentionImpl"] + + +IMPL_ATTR = "impl" +_original_impl = "eager" # mutable + + +class QuantizedAttentionImpl(InternalModule): + def __init__(self, attn_module: torch.nn.Module): + super().__init__() + self.attn_module_container = [attn_module] # avoid circular reference + self._qparams_initialized = False + + def forward( + self, + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + *args, + **kwargs, + ): + # quantization + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled and self._qparams_initialized: + query = forward_quantize(module, query, "q", quant_args) + + # original attention + return ALL_ATTENTION_FUNCTIONS[_original_impl]( + module, + query, + key, + value, + *args, + **kwargs, + ) + + def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module): + assert module is self.attn_module_container[0] + scheme: Optional[QuantizationScheme] = getattr( + module, "quantization_scheme", None + ) + quant_args: Optional[QuantizationArgs] = getattr( + scheme, "input_activations", None + ) + + if ( + not self._qparams_initialized + and quant_args is not None + and not scheme.kv_cache_only + ): + # TODO: use model.config.num_attention_heads to find query_size + assert quant_args.strategy in ( + QuantizationStrategy.TENSOR, + QuantizationStrategy.TOKEN, + QuantizationStrategy.ATTN_HEAD, + ) + + num_heads = model.config.num_attention_heads + hidden_size = model.config.hidden_size + observed_dtype = next(module.parameters()).dtype + _initialize_scale_zero_point( + module, + "q", + quant_args, + observed_shape=(num_heads, hidden_size), + observed_dtype=observed_dtype, + force_zero_point=True, + ) + self._qparams_initialized = True + + +# ----- initialize ----- # + + +def ct_hooked_attention(module: 
torch.nn.Module, *args, **kwargs): + if hasattr(module, IMPL_ATTR): + return module.impl(module, *args, **kwargs) + else: + return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) + + +def initialize_hooked_attention( + model: PreTrainedModel, module: torch.nn.Module, quantize: bool = True +): + if not hasattr(module, IMPL_ATTR): + module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module)) + if model.config._attn_implementation != "ct_hooked_attention": + # assumes only one model at a time + global _original_impl + _original_impl = model.config._attn_implementation + + AttentionInterface.register("ct_hooked_attention", ct_hooked_attention) + model.config._attn_implementation = "ct_hooked_attention" + + impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR) + if quantize: + impl.initialize_qparams_once(model, module) + + initialize_hooked_kv_cache(model, module, quantize=quantize) + + +# ----- hooks ----- # + + +def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: + """ + Registers a forward pre-hook on `module.impl` that replaces the `query` argument + with `hook(mod, query)` (handles both positional and keyword forms). + """ + impl = getattr(module, IMPL_ATTR) + + def _hook(impl: QuantizedAttentionImpl, args, kwargs): + bound = inspect.signature(module.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["query"]) + if value is not None: + bound.arguments["query"] = value + + return bound.args, bound.kwargs + + return impl.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py new file mode 100644 index 000000000..084da9085 --- /dev/null +++ b/src/compressed_tensors/modeling/kvcache.py @@ -0,0 +1,163 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Callable, Optional, Tuple + +import torch +import transformers +from compressed_tensors.quantization import QuantizationStrategy, forward_quantize +from compressed_tensors.quantization.lifecycle.initialize import ( + _initialize_scale_zero_point, +) +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from packaging import version +from torch import Tensor +from torch.utils.hooks import RemovableHandle +from transformers import Cache, PreTrainedModel + + +__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"] + + +KV_CACHE_ATTR = "kv_cache" + + +class QuantizedKVCache(InternalModule): + def __init__(self, attn_module: torch.nn.Module): + super().__init__() + self.attn_module_container = [attn_module] # avoid nn.Module circular reference + self.past_key_values: Optional[Cache] = None + self._qparams_initialized = False + + def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: + return self(*args, **kwargs) + + def forward( + self, + key_states: Tensor, + value_states: Tensor, + *args, + **kwargs, + ) -> Tuple[Tensor, Tensor]: + # quantization + module = self.attn_module_container[0] + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled and self._qparams_initialized: + key_states = forward_quantize(module, key_states, "k", quant_args) + value_states = forward_quantize(module, value_states, "v", quant_args) + + # original cache + if self.past_key_values is not None: + ret = self.past_key_values.update(key_states, value_states, *args, **kwargs) + else: + ret = (key_states, value_states) + + self.past_key_values = None + return ret + + def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module): + assert module is self.attn_module_container[0] + scheme = getattr(module, "quantization_scheme", None) + quant_args = getattr(scheme, "input_activations", None) + + if not self._qparams_initialized and quant_args is not None: + # TODO: use model.config.num_key_value_heads to find key_size, value_size + assert quant_args.strategy in ( + QuantizationStrategy.TENSOR, + QuantizationStrategy.TOKEN, + QuantizationStrategy.ATTN_HEAD, + ) + num_heads = model.config.num_key_value_heads + hidden_size = model.config.hidden_size + observed_dtype = next(module.parameters()).dtype + _initialize_scale_zero_point( + module, + "k", + quant_args, + observed_shape=(num_heads, hidden_size), + observed_dtype=observed_dtype, + force_zero_point=True, + ) + _initialize_scale_zero_point( + module, + "v", + quant_args, + observed_shape=(num_heads, hidden_size), + observed_dtype=observed_dtype, + force_zero_point=True, + ) + self._qparams_initialized = True + + +# ----- initialize ----- # + + +def initialize_hooked_kv_cache( + model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False +): + if not hasattr(module, KV_CACHE_ATTR): + module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module)) + module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True) + + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + if quantize: + kv_cache.initialize_qparams_once(model, module) + + +def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + _past_kv_name = ( + "past_key_value" + if version.parse(transformers.__version__) <= 
version.parse("4.55.4") + else "past_key_values" # transformers#39956 + ) + kv_cache.past_key_values = kwargs.get(_past_kv_name, None) + kwargs[_past_kv_name] = kv_cache + + return args, kwargs + + +# ----- hooks ----- # + + +def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["key_states"]) + if value is not None: + bound.arguments["key_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) + + +def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["value_states"]) + if value is not None: + bound.arguments["value_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index faa48df20..3091d4648 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -13,11 +13,8 @@ # limitations under the License. import logging -from collections import OrderedDict from copy import deepcopy -from typing import Dict, Iterable, List, Optional -from typing import OrderedDict as OrderedDictType -from typing import Union +from typing import Any, Dict, Iterable, List, Optional, Union import torch from compressed_tensors.config import CompressionFormat @@ -26,24 +23,29 @@ ) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, + is_attention_module, ) -from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.quant_config import ( QuantizationConfig, QuantizationStatus, ) from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import ( - KV_CACHE_TARGETS, + ATTN_TARGETS, infer_quantization_status, - is_kv_cache_quant_scheme, ) -from compressed_tensors.utils.helpers import deprecated, replace_module -from compressed_tensors.utils.match import match_named_modules, match_targets -from compressed_tensors.utils.offload import update_parameter_data -from compressed_tensors.utils.safetensors_load import get_safetensors_folder +from compressed_tensors.utils import ( + deprecated, + get_safetensors_folder, + is_narrow_match, + match_named_modules, + match_targets, + replace_module, + update_parameter_data, +) from safetensors import safe_open from torch.nn import Module +from transformers import PreTrainedModel __all__ = [ @@ -134,39 +136,36 @@ def apply_quantization_config( # preprocess to support kv cache scheme config = process_quantization_config(config) - # build mapping of targets to schemes for easier matching - # use ordered dict to preserve target ordering in config - target_to_scheme = OrderedDict() - for scheme in config.config_groups.values(): - for target in scheme.targets: - target_to_scheme[target] = scheme - # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in 
match_named_modules( - model, target_to_scheme, config.ignore, warn_on_fail=True - ): - # mark modules to be quantized by adding - # quant scheme to the matching layers - matched_targets = match_targets(name, submodule, target_to_scheme) - scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) - # target matched - add layer and scheme to target list - submodule.quantization_scheme = scheme - - # replace with run compressed if applicable - # FUTURE: move this to model compressor - if isinstance(submodule, torch.nn.Linear) and run_compressed: - format = config.format - if format != CompressionFormat.dense.value: - if isinstance(submodule, torch.nn.Linear): - # TODO: expand to more module types - compressed_linear = CompressedLinear.from_linear( - submodule, - quantization_scheme=scheme, - quantization_format=format, - ) - replace_module(model, name, compressed_linear) - - # apply current quantization status across all targeted layers + for scheme in config.config_groups.values(): + for name, submodule in match_named_modules( + model, scheme.targets, config.ignore or [], warn_on_fail=True + ): + # attach scheme (overwrite any existing) + setattr(submodule, "quantization_scheme", scheme) + + # replace with run compressed if applicable + # FUTURE: move this to model compressor + if isinstance(submodule, torch.nn.Linear): + if run_compressed: + format = config.format + if format != CompressionFormat.dense.value: + # TODO: expand to more module types + compressed_linear = CompressedLinear.from_linear( + submodule, + quantization_scheme=scheme, + quantization_format=format, + ) + replace_module(model, name, compressed_linear) + + # attention quantization and/or kv cache quantization + if is_attention_module(submodule) and not is_narrow_match( + model, scheme.targets, name + ): + # do not quantize attention unless specifically targeted + delattr(submodule, "quantization_scheme") + + # apply current quantization status across all targeted linear/embedding layers apply_quantization_status(model, config.quantization_status) @@ -183,9 +182,7 @@ def process_quantization_config(config: QuantizationConfig) -> QuantizationConfi return config -def process_kv_cache_config( - config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS -) -> QuantizationConfig: +def process_kv_cache_config(config: QuantizationConfig) -> QuantizationConfig: """ Reformulate the `config.kv_cache` as a `config_group` and add it to the set of existing `config.groups` @@ -193,16 +190,14 @@ def process_kv_cache_config( :param config: the QuantizationConfig :return: the QuantizationConfig with additional "kv_cache" group """ - if targets == KV_CACHE_TARGETS: - _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}") + _LOGGER.info(f"KV cache targets set to default value of: {ATTN_TARGETS}") - kv_cache_dict = config.kv_cache_scheme.model_dump() - kv_cache_scheme = QuantizationScheme( - output_activations=QuantizationArgs(**kv_cache_dict), - targets=targets, + scheme = QuantizationScheme( + targets=ATTN_TARGETS, + input_activations=config.kv_cache_scheme, + kv_cache_only=True, ) - kv_cache_group = dict(kv_cache=kv_cache_scheme) - config.config_groups.update(kv_cache_group) + config.config_groups.update({"kv_cache": scheme}) return config @@ -221,7 +216,9 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply( lambda module: initialize_module_for_quantization( - module, force_zero_point=force_zero_point_init + module, + 
force_zero_point=force_zero_point_init, + model=model, ) ) @@ -254,14 +251,6 @@ def find_name_or_class_matches( return match_targets(name, module, targets) -def _infer_status(model: Module) -> Optional[QuantizationStatus]: - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def _load_quant_args_from_mapping( base_name: str, module_name: str, module: Module, mapping: Dict ): @@ -304,67 +293,3 @@ def _load_quant_args_from_mapping( state_dict_zp = f.get_tensor(f"{module_name}.{zp_name}") update_parameter_data(module, state_dict_zp, zp_name) - - -def _scheme_from_targets( - target_to_scheme: OrderedDictType[str, QuantizationScheme], - targets: List[str], - name: str, -) -> QuantizationScheme: - if len(targets) == 1: - # if `targets` iterable contains a single element - # use it as the key - return target_to_scheme[targets[0]] - - # otherwise, we need to merge QuantizationSchemes corresponding - # to multiple targets. This is most likely because `name` module - # is being target both as an ordinary quantization target, as well - # as kv cache quantization target - schemes_to_merge = [target_to_scheme[target] for target in targets] - return _merge_schemes(schemes_to_merge, name) - - -def _merge_schemes( - schemes_to_merge: List[QuantizationScheme], name: str -) -> QuantizationScheme: - kv_cache_quantization_scheme = [ - scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme) - ] - if not kv_cache_quantization_scheme: - # if the schemes_to_merge do not contain any - # kv cache QuantizationScheme - # return the first scheme (the prioritized one, - # since the order of schemes_to_merge matters) - return schemes_to_merge[0] - else: - # fetch the kv cache QuantizationScheme and the highest - # priority non-kv cache QuantizationScheme and merge them - kv_cache_quantization_scheme = kv_cache_quantization_scheme[0] - quantization_scheme = [ - scheme - for scheme in schemes_to_merge - if not is_kv_cache_quant_scheme(scheme) - ][0] - schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme] - merged_scheme = {} - for scheme in schemes_to_merge: - scheme_dict = { - k: v for k, v in scheme.model_dump().items() if v is not None - } - # when merging multiple schemes, the final target will be - # the `name` argument - hence erase the original targets - del scheme_dict["targets"] - # make sure that schemes do not "clash" with each other - overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys()) - if overlapping_keys: - raise ValueError( - f"The module: {name} is being modified by two clashing " - f"quantization schemes, that jointly try to override " - f"properties: {overlapping_keys}. Fix the quantization config " - "so that it is not ambiguous." 
- ) - merged_scheme.update(scheme_dict) - - merged_scheme.update(targets=[name]) - - return QuantizationScheme(**merged_scheme) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 2e539b070..850d8f1e5 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -264,8 +264,7 @@ def _process_quantization( ): output_dtype = dtype if dtype is not None else x.dtype - output = torch.zeros_like(x).to(output_dtype) - columns = output.shape[-1] + columns = x.size(-1) # TODO: make validation step for inputs @@ -281,17 +280,8 @@ def _process_quantization( f"by the given group_size {group_size}" ) - # support column-order (default) quantization as well as other orderings - # such as activation ordering. Below checks if g_idx has been initialized - is_column_order = g_idx is None or -1 in g_idx - if is_column_order: - num_groups = int(ceil(columns / group_size)) - group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) - - else: - group_indices, group_sizes = torch.unique(g_idx, return_counts=True) - group_sizes = group_sizes[torch.argsort(group_indices)] - + # permute groups + if g_idx is not None: perm = torch.argsort(g_idx) x = x.index_select(-1, perm) @@ -300,6 +290,8 @@ def _process_quantization( ceil(x.shape[-1] / group_size), group_size, ) + # we should potentially be folding reshaped_dims[0] into x.shape[-2] + # in order to allow for multi-headed activations x = x.unflatten(-1, reshaped_dims) if do_quantize: @@ -323,12 +315,12 @@ def _process_quantization( global_scale=global_scale, ) - output = output.flatten(start_dim=-2) + output = output.flatten(-2, -1) output = output.to(output_dtype) - if not is_column_order: - inv_perm = torch.argsort(perm) - output = output.index_select(-1, inv_perm) + # unpermute groups + if g_idx is not None: + x = x.index_select(-1, g_idx) else: # covers channel, token and tensor strategies if do_quantize: diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 5350b4a2c..04b066744 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -14,10 +14,8 @@ import logging -import math -import warnings from enum import Enum -from typing import Optional +from typing import Optional, Tuple import torch from compressed_tensors.quantization.lifecycle.forward import ( @@ -26,6 +24,7 @@ from compressed_tensors.quantization.quant_args import ( FP8_E4M3_DATA, ActivationOrdering, + DynamicType, QuantizationArgs, QuantizationStrategy, ) @@ -38,6 +37,7 @@ register_offload_parameter, ) from torch.nn import Module, Parameter +from transformers import PreTrainedModel __all__ = [ @@ -59,6 +59,7 @@ def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, force_zero_point: bool = True, + model: Optional[PreTrainedModel] = None, ): """ attaches appropriate scales, zero points, and observers to a layer @@ -73,49 +74,64 @@ def initialize_module_for_quantization( :param force_zero_point: whether to force initialization of a zero point for symmetric quantization """ - # TODO: don't initialize parameters when running decompression scheme = scheme or getattr(module, "quantization_scheme", None) if scheme is None: - # no scheme passed and layer not targeted for quantization - skip return if 
is_attention_module(module): - # quantized actions based on calltime status - _initialize_attn_scales(module) + from compressed_tensors.modeling.attention import initialize_hooked_attention + + if not isinstance(model, PreTrainedModel): + raise ValueError("Must pass model in order to initialize attention") + initialize_hooked_attention(model, module, quantize=True) else: + if not isinstance(module, torch.nn.Linear): + _LOGGER.warning(f"Attempting to quantize module of type {type(module)}") + + # use weight to determine observed shapes and dtype + if hasattr(module, "weight"): + weight = module.weight + assert isinstance(weight, torch.Tensor) + else: + # Note that a weight is required for both weight and activation + # quantization in order to know the dtype of activation scales + _LOGGER.warning( + f"module type {type(module)} targeted for quantization but " + f"has no attribute weight, skipping quantization for {type(module)}" + ) + return + if scheme.input_activations is not None: _initialize_scale_zero_point( module, "input", scheme.input_activations, + observed_shape=(1, weight.shape[-1]), + observed_dtype=weight.dtype, force_zero_point=force_zero_point, ) if scheme.weights is not None: - if hasattr(module, "weight"): - weight_shape = None - if isinstance(module, torch.nn.Linear): - weight_shape = module.weight.shape - _initialize_scale_zero_point( - module, - "weight", - scheme.weights, - weight_shape=weight_shape, - force_zero_point=force_zero_point, - ) - else: - _LOGGER.warning( - f"module type {type(module)} targeted for weight quantization but " - "has no attribute weight, skipping weight quantization " - f"for {type(module)}" - ) - - if scheme.output_activations is not None: - if not is_kv_cache_quant_scheme(scheme): - _initialize_scale_zero_point( - module, "output", scheme.output_activations - ) + _initialize_scale_zero_point( + module, + "weight", + scheme.weights, + observed_shape=weight.shape, + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) + + output_is_kv_cache = is_kv_cache_quant_scheme(scheme) + if scheme.output_activations is not None and not output_is_kv_cache: + _initialize_scale_zero_point( + module, + "output", + scheme.output_activations, + observed_shape=weight.shape[:-1], + observed_dtype=weight.dtype, + force_zero_point=force_zero_point, + ) module.quantization_scheme = scheme module.quantization_status = QuantizationStatus.INITIALIZED @@ -138,18 +154,21 @@ def _initialize_scale_zero_point( module: Module, base_name: str, quantization_args: QuantizationArgs, - weight_shape: Optional[torch.Size] = None, + observed_shape: Tuple[int], + observed_dtype: torch.dtype, force_zero_point: bool = True, ): - if quantization_args.dynamic is True: - return + strategy = quantization_args.strategy + dynamic = quantization_args.dynamic + actorder = quantization_args.actorder + device = get_execution_device(module) # avoid performing intialization ops on cpu - # initialize on execution device to avoid performing quantized ops on cpu - device = get_execution_device(module) + # Skip all intialization for fully dynamic quantization + if dynamic is True: + return - # 1. Create global_scales for tensor_group - generates - # a per tensor scale - if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: + # 0. 
Create global scale for tensor-group quantization + if strategy == QuantizationStrategy.TENSOR_GROUP: init_global_scale = Parameter( torch.empty(1, dtype=torch.float32, device=device), requires_grad=False, @@ -158,56 +177,58 @@ def _initialize_scale_zero_point( module, f"{base_name}_global_scale", init_global_scale ) - # 2. Infer expected scale/zero point shape - if quantization_args.strategy == QuantizationStrategy.TOKEN: + # Skip scale/zp initialization for locally dynamic quantization + if dynamic == DynamicType.LOCAL: + return + + # 1. Infer expected scale/zp shape + if strategy == QuantizationStrategy.TENSOR: + expected_shape = (1,) + + elif strategy == QuantizationStrategy.TOKEN: expected_shape = (1, 1) + + elif strategy == QuantizationStrategy.CHANNEL: + if len(observed_shape) < 1: + raise ValueError("Channel quant requires at least 1 observed dimension") + + expected_shape = (observed_shape[-2], 1) + + elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + assert quantization_args.group_size is not None + if len(observed_shape) < 1: + raise ValueError("Group quant requires at least 1 observed dimension") + + group_size = quantization_args.group_size + num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy) + expected_shape = (*observed_shape[:-1], num_groups) + + # initialize activation ordering if applicable + if actorder == ActivationOrdering.GROUP: + init_g_idx = Parameter( + torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) + + elif strategy == QuantizationStrategy.BLOCK: + assert quantization_args.block_structure is not None + if len(observed_shape) < 2: + raise ValueError("Block quant requires at least 2 observed dimensions") + + block_structure = quantization_args.block_structure + num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy) + num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy) + expected_shape = (num_rows, num_cols) + + elif strategy == QuantizationStrategy.ATTN_HEAD: + expected_shape = (observed_shape[-2], 1) + else: - expected_shape = 1 - - if base_name == "weight" and weight_shape is not None: - if quantization_args.strategy == QuantizationStrategy.CHANNEL: - # (output_channels, 1) - only for weights - expected_shape = (weight_shape[0], 1) - elif quantization_args.strategy in ( - QuantizationStrategy.TENSOR_GROUP, - QuantizationStrategy.GROUP, - ): - # GROUP/TENSOR_GROUP for both weights and activations - num_groups = math.ceil(weight_shape[1] / quantization_args.group_size) - expected_shape = (weight_shape[0], max(num_groups, 1)) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - # For block quantization, scale shape should match number of blocks - only - # for weights - if quantization_args.block_structure is None: - raise ValueError( - "Block quantization requires block_structure to be specified" - ) - block_height, block_width = quantization_args.block_structure - rows, cols = weight_shape[-2], weight_shape[-1] - num_rows_blocks = math.ceil(rows / block_height) - num_cols_blocks = math.ceil(cols / block_width) - - # Warn if dimensions don't divide evenly - if rows % block_height != 0 or cols % block_width != 0: - warnings.warn( - f"Block quantization: tensor shape {weight_shape} does not divide" - f"evenly by block structure {quantization_args.block_structure}. 
" - f"Some blocks will be incomplete which may affect quantization" - "quality.", - UserWarning, - ) - - expected_shape = (num_rows_blocks, num_cols_blocks) - elif quantization_args.strategy == QuantizationStrategy.BLOCK: - warnings.warn( - f"BLOCK quantization not supported for {base_name} activations. " - f"Falling back to tensor-level quantization.", - UserWarning, - ) - expected_shape = 1 + assert False, f"Unknown strategy {strategy}" - # 3. Identify quantization scale and zp dtype - scale_dtype = module.weight.dtype + # 2. Identify quantization scale and zp dtype + scale_dtype = observed_dtype if is_fp4(quantization_args=quantization_args): scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype @@ -223,14 +244,12 @@ def _initialize_scale_zero_point( scale_dtype = torch.bfloat16 zp_dtype = quantization_args.pytorch_dtype() - # 4. Initializes empty scale, zero point, and g_idx parameters for the module - # do not init scales for quantzation_args.dynamic == DynamicType.local - if not quantization_args.dynamic: - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_scale", init_scale) + # 3. Initializes scale/zp for the module + init_scale = Parameter( + torch.empty(expected_shape, dtype=scale_dtype, device=device), + requires_grad=False, + ) + register_offload_parameter(module, f"{base_name}_scale", init_scale) if force_zero_point or not quantization_args.symmetric: init_zero_point = Parameter( @@ -238,35 +257,3 @@ def _initialize_scale_zero_point( requires_grad=False, ) register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point) - - # only grouped activation ordering has g_idx - if quantization_args.actorder == ActivationOrdering.GROUP: - g_idx_shape = (weight_shape[1],) - g_idx_dtype = torch.int - init_g_idx = Parameter( - torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype), - requires_grad=False, - ) - register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx) - - -def _initialize_attn_scales(module: Module) -> None: - """Initlaize k_scale, v_scale for self_attn""" - - expected_shape = 1 # per tensor - - param = next(module.parameters()) - scale_dtype = param.dtype - device = param.device - - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.KEY.value, init_scale) - - init_scale = Parameter( - torch.empty(expected_shape, dtype=scale_dtype, device=device), - requires_grad=False, - ) - register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index d9e88353b..1e4e1bf41 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -101,6 +101,7 @@ class QuantizationStrategy(str, Enum): BLOCK = "block" TOKEN = "token" TENSOR_GROUP = "tensor_group" + ATTN_HEAD = "attn_head" class DynamicType(str, Enum): @@ -262,6 +263,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": actorder = model.actorder dynamic = model.dynamic observer = model.observer + block_structure = model.block_structure # infer strategy if strategy is None: @@ -277,23 +279,29 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs": "strategy='group' and group_size = -1 for 'channel'" ) - # validate strategy and group - if 
strategy == QuantizationStrategy.GROUP: - if group_size is None or group_size <= 0: - raise ValueError( - f"strategy {strategy} requires group_size to be " - "set to a positive value" - ) - if ( - group_size is not None - and group_size > 0 - and strategy - not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP) - ): - raise ValueError("group_size requires strategy to be set to 'group'") - - # validate activation ordering and strategy - if actorder is not None and strategy != QuantizationStrategy.GROUP: + # validate block strategy and structure + has_block_strategy = strategy == QuantizationStrategy.BLOCK + has_block_structure = block_structure is not None + if has_block_strategy != has_block_structure: + raise ValueError( + "`strategy = block` requires `block_structure != None`, and vice versa." + f" Instead got `strategy={strategy}` and " + f"`block_structure={block_structure}`" + ) + + # validate group strategy + has_group_strategy = strategy in ( + QuantizationStrategy.GROUP, + QuantizationStrategy.TENSOR_GROUP, + ) + has_group_size = group_size is not None and group_size > 0 + has_actorder = actorder is not None + if has_group_strategy != has_group_size: + raise ValueError( + "`strategy = group` requires `group_size != None`, and vice versa. " + f"Instead got `strategy={strategy}` and `group_size={group_size}`" + ) + if has_actorder and not has_group_strategy: raise ValueError( "Must use group quantization strategy in order to apply " "activation ordering" @@ -356,6 +364,9 @@ def pytorch_dtype(self) -> torch.dtype: else: raise ValueError(f"Invalid quantization type {self.type}") + def is_online(self) -> bool: + return self.dynamic is True + @deprecated("QuantizationArgs.observer") def get_observer(self) -> str: return self.observer diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index b11e3c0c0..e00bb899a 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -14,7 +14,7 @@ import warnings from copy import deepcopy -from typing import List, Optional +from typing import Any, List, Optional from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_args import ( @@ -23,6 +23,7 @@ QuantizationStrategy, QuantizationType, ) +from loguru import logger from pydantic import BaseModel, ConfigDict, model_validator @@ -44,6 +45,7 @@ class QuantizationScheme(BaseModel): :param input_activations: quantization config for layer inputs :param output_activations: quantization config for layer outputs :param format: CompressionFormat for the layer + TODO """ targets: List[str] @@ -51,6 +53,7 @@ class QuantizationScheme(BaseModel): input_activations: Optional[QuantizationArgs] = None output_activations: Optional[QuantizationArgs] = None format: Optional[str] = None + kv_cache_only: Optional[bool] = None @model_validator(mode="after") def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": @@ -60,24 +63,19 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": format = model.format if inputs is not None: - if inputs.strategy not in ( - QuantizationStrategy.TOKEN, - QuantizationStrategy.TENSOR, - QuantizationStrategy.GROUP, - QuantizationStrategy.TENSOR_GROUP, - ): - if ( - inputs.strategy == QuantizationStrategy.GROUP - and inputs.dynamic is True - ): - raise NotImplementedError( - "Static and local group-wise activation " - "quantization is not supported" - ) - 
- raise NotImplementedError( - f"Using {inputs.strategy} strategy is not supported for " - "activation quantization" + if inputs.strategy == QuantizationStrategy.CHANNEL: + raise ValueError( + "Channel-wise activation quantization is equivalent to " + "tensor/token-wise activation quantization, please use one of " + "those. If you mean to quantize each activation value " + "individually, please use group quantization with `group_size = 1`" + ) + + if inputs.strategy == QuantizationStrategy.BLOCK: + raise ValueError( + "Block-wise activation quantization is not supported. If you mean " + "to quantize each activation value individually, please use group " + "quantization with `group_size = 1`" ) if inputs.actorder is not None: diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 73d545193..630d1c776 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -39,7 +39,7 @@ "get_torch_bit_depth", "can_quantize", "parse_out_kv_cache_args", - "KV_CACHE_TARGETS", + "ATTN_TARGETS", "is_kv_cache_quant_scheme", "iter_named_leaf_modules", "iter_named_quantizable_modules", @@ -48,11 +48,11 @@ "calculate_qparams", "generate_gparam", "is_fp4", + "strict_divide", ] -# target the self_attn layer -# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale -KV_CACHE_TARGETS = ["re:.*self_attn$"] +# note that this is a "narrow match", see quantization/apply.py +ATTN_TARGETS = ["re:.*self_attn$"] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -409,17 +409,17 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool: """ Check whether the QuantizationScheme targets the kv cache. It does if all the following criteria are met: - - the scheme targets either exactly match the KV_CACHE_TARGETS - or the match KV_CACHE_TARGETS regex pattern + - the scheme targets either exactly match the ATTN_TARGETS + or the match ATTN_TARGETS regex pattern - the scheme quantizes output_activations (we want to quantize the - outputs from the KV_CACHE_TARGETS, as their correspond to the + outputs from the ATTN_TARGETS, as their correspond to the keys and values that are to be saved in the cache) :param scheme: The QuantizationScheme to investigate :return: boolean flag """ for target in scheme.targets: - if target in KV_CACHE_TARGETS: + if target in ATTN_TARGETS: return True return False @@ -477,3 +477,19 @@ def generate_gparam( max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) global_scale = scale_data.max * quant_data.max / max_val_pos return global_scale.to(dtype).reshape([1]) + + +def strict_divide( + observed: int, divisor: int, strategy: Optional[QuantizationStrategy] = None +) -> int: + out = observed // divisor + if out * divisor != observed: + if strategy is not None: + raise ValueError( + f"{strategy} quantization strategy requires strict division of " + f"weight/activation size {observed} and group/block size {divisor}. 
" + "consider reducing the group/block size or ignoring modules with " + f"weights not divisible by {divisor}" + ) + + return out diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 94e6b4a42..8c56ca14a 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -19,6 +19,14 @@ import torch import torch.nn.utils.parametrize as P import tqdm +from compressed_tensors.modeling.attention import ( + initialize_hooked_attention, + register_query_hook, +) +from compressed_tensors.modeling.kvcache import ( + initialize_hooked_kv_cache, + register_key_hook, +) from compressed_tensors.registry.registry import RegistryMixin, T from compressed_tensors.transform import ( TransformArgs, @@ -37,6 +45,7 @@ from compressed_tensors.utils.internal import InternalModule from torch import Tensor from torch.nn import Module, Parameter +from transformers import PreTrainedModel __all__ = ["TransformFactory", "TransformBase"] @@ -99,11 +108,13 @@ def apply_to_model(self, model: Module, use_tqdm=True): desc = f"Applying {self.name} transforms" for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)): - self._apply_to_module(module, arg) + self._apply_to_module(module, arg, model) self._update_tied_weights() - def _apply_to_module(self, module: Module, args: TransformArgs): + def _apply_to_module( + self, module: Module, args: TransformArgs, model: PreTrainedModel + ): """ Create transforms and apply them to the module @@ -161,9 +172,25 @@ def output_hook(_, _input, output): module.register_forward_hook(output_hook) + elif args.location == TransformLocation.Q_ATTN: + initialize_hooked_attention(model, module, quantize=False) + + def query_hook(_, query_states): + return transform(query_states) + + register_query_hook(module, query_hook) + # other locations such as q_attn and k_attn have not been implemented + elif args.location == TransformLocation.K_CACHE: + initialize_hooked_kv_cache(model, module, quantize=False) + + def key_hook(_, key_states): + return transform(key_states) + + register_key_hook(module, key_hook) + else: - raise NotImplementedError() + assert False def _update_tied_weights(self): """ diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index de6e284bb..b827ffe2b 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -22,8 +22,12 @@ apply_transform_weight, get_transform_size, ) -from compressed_tensors.utils import get_execution_device, get_offloaded_device from compressed_tensors.utils.helpers import ParameterizedDefaultDict +from compressed_tensors.utils.offload import ( + get_execution_device, + get_offloaded_device, + has_offloaded_params, +) from torch import Tensor, device, dtype from torch.nn import Module, Parameter @@ -51,7 +55,6 @@ def create_transform(self, module: Module, args: TransformArgs): :param module: parent module that transform will be applied to :param args: defines how the transform will be applied to the module """ - assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) exec_device = get_execution_device(module) device = get_offloaded_device(module) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index d3f469579..75c816492 100644 --- 
a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -45,6 +45,12 @@ class TransformLocation(str, Enum): K_CACHE = "k_cache" Q_ATTN = "q_attn" + def is_online(self) -> bool: + return self not in ( + TransformLocation.WEIGHT_INPUT, + TransformLocation.WEIGHT_OUTPUT, + ) + class TransformArgs(BaseModel, use_enum_values=True): """ @@ -70,9 +76,6 @@ def wrap_singleton(cls, value): return value def is_online(self) -> bool: - return self.location not in ( - TransformLocation.WEIGHT_INPUT, - TransformLocation.WEIGHT_OUTPUT, - ) + return TransformLocation(self.location).is_online() model_config = ConfigDict(extra="forbid") diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py index 920728571..0414e3f69 100644 --- a/src/compressed_tensors/transform/utils/matrix.py +++ b/src/compressed_tensors/transform/utils/matrix.py @@ -34,6 +34,8 @@ def get_transform_size( :param head_dim: size of head when transform is applied to mha :return: size of matrix """ + size = None + if isinstance(module, torch.nn.Linear): if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT): size = module.in_features @@ -44,11 +46,13 @@ def get_transform_size( size = module.num_embeddings else: size = module.embedding_dim - else: - raise NotImplementedError(f"Transforms on {type(module)} are not supported") + elif head_dim is None: + raise NotImplementedError( + f"Transforms on {type(module)} are not supported without head_dim" + ) if head_dim is not None: - if size % head_dim != 0: + if size is not None and size % head_dim != 0: raise ValueError( f"{head_dim} must divide {size} for {type(module)} at {location}" ) @@ -105,11 +109,11 @@ def apply_transform_weight( assert transform_weight.shape[0] == transform_weight.shape[1] - if module_type == torch.nn.Linear: - if location == TransformLocation.INPUT: - return _multihead_matmul(value, transform_weight) + if TransformLocation(location).is_online(): + return _multihead_matmul(value, transform_weight) - elif location == TransformLocation.WEIGHT_INPUT: + if module_type == torch.nn.Linear: + if location == TransformLocation.WEIGHT_INPUT: # equivalent to (transform_weight @ value.T).T return _multihead_matmul(value, transform_weight.T) @@ -117,26 +121,14 @@ def apply_transform_weight( # equivalent to (value.T @ transform_weight).T return _multihead_matmul(transform_weight.T, value) - elif location == TransformLocation.OUTPUT: - return _multihead_matmul(value, transform_weight) - # similar derivation to torch.nn.Linear, but `y = (x W)` elif module_type == torch.nn.Embedding: - if location == TransformLocation.INPUT: - return _multihead_matmul(value, transform_weight) - - elif location == TransformLocation.WEIGHT_INPUT: - return _multihead_matmul( - transform_weight, - value, - ) + if location == TransformLocation.WEIGHT_INPUT: + return _multihead_matmul(transform_weight, value) elif location == TransformLocation.WEIGHT_OUTPUT: return _multihead_matmul(value, transform_weight) - elif location == TransformLocation.OUTPUT: - return _multihead_matmul(value, transform_weight) - raise NotImplementedError( f"Applying transforms to {module_type} {location} is not supported" ) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 11e2a2a1c..3d423981a 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -30,6 +30,7 @@ "match_targets", "match_modules_set", "is_match", + 
"is_narrow_match", ] @@ -128,7 +129,6 @@ def match_targets( :param module: the module to match :param targets: the target strings, potentially containing "re:" prefixes :return: the targets that match the given name and module - Outputs are ordered by type: exact name match, regex name match, class name match """ targets = targets or [] @@ -305,3 +305,13 @@ def _match_class(module: torch.nn.Module, target: str) -> bool: ) for cls in module.__class__.__mro__ ) + + +def is_narrow_match(model: torch.nn.Module, targets: Iterable[str], name: str) -> bool: + module = model.get_submodule(name) + parent_name = name.rsplit(".", 1)[0] + parent = model.get_submodule(parent_name) + + return is_match(name, module, targets) and not is_match( + parent_name, parent, targets + ) diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index 115cf3f5a..809d8248b 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -344,7 +344,8 @@ def _get_combined_config(s_config, q_config): ) def test_compress_model(model_stub, q_format, s_config, tmpdir): model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32) - compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format]) + qformats = None if q_format is None else [q_format] # FUTURE: remove nullability + compressor = ModelCompressor.from_pretrained_model(model, s_config, qformats) # compress model by eagerly compressing state dict true_compressed = dict(compressor.compress(model)) @@ -391,8 +392,9 @@ def test_compress_model_meta(model_stub, q_format, s_config): cpu_model = AutoModelForCausalLM.from_pretrained( model_stub, torch_dtype=torch.float32 ) + qformats = None if q_format is None else [q_format] # FUTURE: remove nullability reference_compressor = ModelCompressor.from_pretrained_model( - cpu_model, s_config, [q_format] + cpu_model, s_config, qformats ) # Only stores dtype because meta model does not store values expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()} @@ -408,7 +410,7 @@ def test_compress_model_meta(model_stub, q_format, s_config): module.to_empty(device="meta") # Compress in-place on meta model - compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format]) + compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, qformats) compressor.compress_model(meta_model) # Compare keys and dtypes diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index d5fd6c2cd..52b301ed5 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -57,53 +57,6 @@ def llama_stories_model(): ) -def test_target_prioritization(mock_frozen): - # tests that the config_groups are applied in the correct order - # of priority, where exact layer name > regex > module name - config = { - "quant_method": "compressed-tensors", - "format": "fakequant", - "config_groups": { - "group_1": { - "weights": { - "num_bits": 8, - }, - "targets": ["Linear"], - }, - "group_2": { - "weights": { - "num_bits": 4, - }, - "targets": ["re:.*down_proj"], - }, - "group_3": { - "weights": { - "num_bits": 2, - }, - "targets": ["model.layers.0.mlp.down_proj"], - }, - }, - } - - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", torch_dtype="auto" - 
) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - apply_quantization_config(model, config) - mock_frozen(model) - - for name, module in model.named_modules(): - if name == "model.layers.0.mlp.down_proj": - assert module.quantization_scheme.weights.num_bits == 2 - elif re.match(".*down_proj", name): - assert module.quantization_scheme.weights.num_bits == 4 - elif isinstance(module, torch.nn.Linear): - assert module.quantization_scheme.weights.num_bits == 8 - - def test_apply_quantization_config_tinyllama(): quant_config = get_sample_tinyllama_quant_config(status="calibration") model = get_tinyllama_model() @@ -174,13 +127,16 @@ def test_serialize_config_tinyllama(): serialized_config = QuantizationConfig.from_pretrained(model) assert len(serialized_config.config_groups) == 2 - assert serialized_config.config_groups["group_0"].targets == ["Embedding"] - assert serialized_config.config_groups["group_0"].input_activations is None assert serialized_config.config_groups["group_1"].targets == ["Linear"] assert serialized_config.config_groups["group_1"].input_activations is not None + assert serialized_config.config_groups["group_2"].targets == ["Embedding"] + assert serialized_config.config_groups["group_2"].input_activations is None assert serialized_config.format == CompressionFormat.dense.value assert serialized_config.quant_method == DEFAULT_QUANTIZATION_METHOD - assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"] + assert serialized_config.ignore == [ + "LlamaRotaryEmbedding", + "model.layers.1.mlp.down_proj", + ] if serialized_config.global_compression_ratio is not None: assert serialized_config.global_compression_ratio > 1.0 assert serialized_config.global_compression_ratio < 8.0 diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index f3321cd40..09010af06 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -95,7 +95,7 @@ def test_forward_quantize( @pytest.mark.parametrize( - "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale", + "num_bits,type,strategy,group_size,scale,zero_point,g_idx,global_scale,batch_size", [ ( 4, @@ -106,6 +106,7 @@ def test_forward_quantize( torch.zeros((1,)), None, None, + None, ), ( 4, @@ -116,6 +117,7 @@ def test_forward_quantize( torch.zeros((512, 8)), None, None, + None, ), ( 4, @@ -126,6 +128,7 @@ def test_forward_quantize( torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), None, + None, ), ( 8, @@ -136,6 +139,7 @@ def test_forward_quantize( torch.zeros((1,)), None, None, + None, ), ( 8, @@ -146,6 +150,7 @@ def test_forward_quantize( torch.zeros((512, 8)), None, None, + None, ), ( 8, @@ -156,6 +161,7 @@ def test_forward_quantize( torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), None, + None, ), ( 8, @@ -166,6 +172,7 @@ def test_forward_quantize( torch.zeros((512, 8)), None, None, + None, ), ( 8, @@ -176,17 +183,41 @@ def test_forward_quantize( torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), None, + None, + ), + ( + 8, + "int", + QuantizationStrategy.GROUP, + 128, + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8)), + make_dummy_g_idx(1024, 128), + None, + 5, ), ], ) -def test_fake_quantize_2d( - num_bits, type, strategy, group_size, scale, zero_point, g_idx, global_scale +def test_fake_quantize( + num_bits, + type, + strategy, + group_size, + scale, + zero_point, + g_idx, + global_scale, + batch_size, ): args = 
QuantizationArgs( num_bits=num_bits, type=type, strategy=strategy, group_size=group_size ) - x = torch.rand((512, 1024)) + if batch_size is None: + x = torch.rand((512, 1024)) + else: + x = torch.rand((batch_size, 512, 1024)) + fake_quantize( x=x, scale=scale, diff --git a/tests/test_quantization/test_quant_config.py b/tests/test_quantization/test_quant_config.py index c3830a02d..cbe07183a 100644 --- a/tests/test_quantization/test_quant_config.py +++ b/tests/test_quantization/test_quant_config.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +from compressed_tensors import CompressionFormat from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_FORMAT, DEFAULT_QUANTIZATION_METHOD,