|
| 1 | +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, |
| 10 | +# software distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import inspect |
| 16 | +from typing import Callable, Optional, Tuple |
| 17 | + |
| 18 | +import torch |
| 19 | +import transformers |
| 20 | +from compressed_tensors.quantization import QuantizationStrategy, forward_quantize |
| 21 | +from compressed_tensors.quantization.lifecycle.initialize import ( |
| 22 | + _initialize_scale_zero_point, |
| 23 | +) |
| 24 | +from compressed_tensors.utils import getattr_chain |
| 25 | +from compressed_tensors.utils.internal import InternalModule |
| 26 | +from packaging import version |
| 27 | +from torch import Tensor |
| 28 | +from torch.utils.hooks import RemovableHandle |
| 29 | +from transformers import Cache, PreTrainedModel |
| 30 | + |
| 31 | + |
| 32 | +__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"] |
| 33 | + |
| 34 | + |
| 35 | +KV_CACHE_ATTR = "kv_cache" |
| 36 | + |
| 37 | + |
| 38 | +class QuantizedKVCache(InternalModule): |
| 39 | + def __init__(self, attn_module: torch.nn.Module): |
| 40 | + super().__init__() |
| 41 | + self.attn_module_container = [attn_module] # avoid nn.Module circular reference |
| 42 | + self.past_key_values: Optional[Cache] = None |
| 43 | + self._qparams_initialized = False |
| 44 | + |
| 45 | + def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: |
| 46 | + return self(*args, **kwargs) |
| 47 | + |
| 48 | + def forward( |
| 49 | + self, |
| 50 | + key_states: Tensor, |
| 51 | + value_states: Tensor, |
| 52 | + *args, |
| 53 | + **kwargs, |
| 54 | + ) -> Tuple[Tensor, Tensor]: |
| 55 | + # quantization |
| 56 | + module = self.attn_module_container[0] |
| 57 | + quant_args_attr = "quantization_scheme.input_activations" |
| 58 | + quant_args = getattr_chain(module, quant_args_attr, None) |
| 59 | + quant_enabled = getattr(module, "quantization_enabled", True) |
| 60 | + if quant_args is not None and quant_enabled and self._qparams_initialized: |
| 61 | + key_states = forward_quantize(module, key_states, "k", quant_args) |
| 62 | + value_states = forward_quantize(module, value_states, "v", quant_args) |
| 63 | + |
| 64 | + # original cache |
| 65 | + if self.past_key_values is not None: |
| 66 | + ret = self.past_key_values.update(key_states, value_states, *args, **kwargs) |
| 67 | + else: |
| 68 | + ret = (key_states, value_states) |
| 69 | + |
| 70 | + self.past_key_values = None |
| 71 | + return ret |
| 72 | + |
| 73 | + def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module): |
| 74 | + assert module is self.attn_module_container[0] |
| 75 | + scheme = getattr(module, "quantization_scheme", None) |
| 76 | + quant_args = getattr(scheme, "input_activations", None) |
| 77 | + |
| 78 | + if not self._qparams_initialized and quant_args is not None: |
| 79 | + # TODO: use model.config.num_key_value_heads to find key_size, value_size |
| 80 | + assert quant_args.strategy == QuantizationStrategy.TENSOR |
| 81 | + _initialize_scale_zero_point(module, "k", quant_args) |
| 82 | + _initialize_scale_zero_point(module, "v", quant_args) |
| 83 | + self._qparams_initialized = True |
| 84 | + |
| 85 | + |
| 86 | +# ----- initialize ----- # |
| 87 | + |
| 88 | + |
| 89 | +def initialize_hooked_kv_cache( |
| 90 | + model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False |
| 91 | +): |
| 92 | + if not hasattr(module, KV_CACHE_ATTR): |
| 93 | + module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module)) |
| 94 | + module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True) |
| 95 | + |
| 96 | + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) |
| 97 | + if quantize: |
| 98 | + kv_cache.initialize_qparams_once(model, module) |
| 99 | + |
| 100 | + |
| 101 | +def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs): |
| 102 | + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) |
| 103 | + _past_kv_name = ( |
| 104 | + "past_key_values" # transformers#39956 |
| 105 | + if "past_key_values" in inspect.signature(module.forward).parameters |
| 106 | + else "past_key_value" |
| 107 | + ) |
| 108 | + kv_cache.past_key_values = kwargs.get(_past_kv_name, None) |
| 109 | + kwargs[_past_kv_name] = kv_cache |
| 110 | + |
| 111 | + return args, kwargs |
| 112 | + |
| 113 | + |
| 114 | +# ----- hooks ----- # |
| 115 | + |
| 116 | + |
| 117 | +def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: |
| 118 | + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) |
| 119 | + |
| 120 | + def _hook(cache: QuantizedKVCache, args, kwargs): |
| 121 | + bound = inspect.signature(cache.forward).bind(*args, **kwargs) |
| 122 | + value = hook(module, bound.arguments["key_states"]) |
| 123 | + if value is not None: |
| 124 | + bound.arguments["key_states"] = value |
| 125 | + |
| 126 | + return bound.args, bound.kwargs |
| 127 | + |
| 128 | + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) |
| 129 | + |
| 130 | + |
| 131 | +def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle: |
| 132 | + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) |
| 133 | + |
| 134 | + def _hook(cache: QuantizedKVCache, args, kwargs): |
| 135 | + bound = inspect.signature(cache.forward).bind(*args, **kwargs) |
| 136 | + value = hook(module, bound.arguments["value_states"]) |
| 137 | + if value is not None: |
| 138 | + bound.arguments["value_states"] = value |
| 139 | + |
| 140 | + return bound.args, bound.kwargs |
| 141 | + |
| 142 | + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) |
0 commit comments