 import inspect
 from typing import Callable, Optional, Tuple

-import torch
-import transformers
 from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
 from compressed_tensors.quantization.lifecycle.initialize import (
     _initialize_scale_zero_point,
 )
 from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.internal import InternalModule
-from packaging import version
 from torch import Tensor
+from torch.nn import Module
 from torch.utils.hooks import RemovableHandle
 from transformers import Cache, PreTrainedModel


-__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"]
+__all__ = [
+    "QuantizedKVCache",
+    "initialize_hooked_kv_cache",
+    "register_key_hook",
+    "register_value_hook",
+]


 KV_CACHE_ATTR = "kv_cache"


 class QuantizedKVCache(InternalModule):
-    def __init__(self, attn_module: torch.nn.Module):
+    """
+    QuantizedKVCache module which wraps the functionality of any existing kv cache.
+    Unlike transformers `Cache` instances, this cache is a `torch.nn.Module` which
+    can be hooked to trigger transforms and calibration hooks.
+
+    This module works by being registered as a submodule of attention modules via
+    `initialize_hooked_kv_cache`, then adding a hook which replaces the
+    `past_key_values` kwarg with this module. This module adopts the functionality
+    of the replaced cache, preserving caching functionality such as sliding window
+    attention, etc.
+
+    :param attn_module: parent attention module
+    """
+
+    def __init__(self, attn_module: Module):
         super().__init__()
-        self.attn_module_container = [attn_module]  # avoid nn.Module circular reference
+        self.attn_module_container = [attn_module]  # avoid circular reference
         self.past_key_values: Optional[Cache] = None
         self._qparams_initialized = False
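The hunk below shows only the tail of `forward`. For orientation, here is a minimal sketch of the delegation pattern the docstring describes; it is an assumption-laden reconstruction rather than the PR's exact body, and it presumes compressed-tensors' `forward_quantize(module, value, base_name, args)` helper and the wrapped cache's standard `update` signature:

    def forward(self, key_states: Tensor, value_states: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        # fake-quantize key/value states using the qparams registered on the attention module
        module = self.attn_module_container[0]
        quant_args = getattr_chain(module, "quantization_scheme.input_activations", None)
        if quant_args is not None:
            key_states = forward_quantize(module, key_states, "k", quant_args)
            value_states = forward_quantize(module, value_states, "v", quant_args)

        # delegate to the wrapped cache so behavior such as sliding windows is preserved
        ret = self.past_key_values.update(key_states, value_states, *args, **kwargs)

        self.past_key_values = None  # release the reference after use, as the diff's tail shows
        return ret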
@@ -70,13 +86,19 @@ def forward(
         self.past_key_values = None
         return ret

-    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
+    def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
+        """
+        Initialize kv cache quantization parameters if they have not already been
+        initialized
+
+        :param model: parent model of attention module
+        :param module: attention module to initialize with
+        """
         assert module is self.attn_module_container[0]
         scheme = getattr(module, "quantization_scheme", None)
         quant_args = getattr(scheme, "input_activations", None)

         if not self._qparams_initialized and quant_args is not None:
-            # TODO: use model.config.num_key_value_heads to find key_size, value_size
             assert quant_args.strategy == QuantizationStrategy.TENSOR
             _initialize_scale_zero_point(module, "k", quant_args)
             _initialize_scale_zero_point(module, "v", quant_args)
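Worth noting what these two calls leave behind: in compressed-tensors, `_initialize_scale_zero_point` registers qparams on the passed module under a `{base_name}_scale` / `{base_name}_zero_point` naming convention, so after initialization the attention module should carry parameters along these lines (names assumed from that convention, not spelled out in this diff):

    attn_module = model.model.layers[0].self_attn  # assumes a Llama-style module tree
    assert hasattr(attn_module, "k_scale") and hasattr(attn_module, "k_zero_point")
    assert hasattr(attn_module, "v_scale") and hasattr(attn_module, "v_zero_point")
    assert attn_module.k_scale.numel() == 1  # TENSOR strategy: one scale per tensor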
@@ -86,19 +108,7 @@ def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Modul
 # ----- initialize ----- #


-def initialize_hooked_kv_cache(
-    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False
-):
-    if not hasattr(module, KV_CACHE_ATTR):
-        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
-        module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True)
-
-    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
-    if quantize:
-        kv_cache.initialize_qparams_once(model, module)
-
-
-def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
+def _kv_cache_attention_hook(module: Module, args, kwargs):
     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
     _past_kv_name = (
         "past_key_values"  # transformers#39956
@@ -111,10 +121,38 @@ def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
     return args, kwargs


+def initialize_hooked_kv_cache(
+    model: PreTrainedModel, module: Module, quantize: bool = False
+):
+    """
+    Initialize a `QuantizedKVCache` instance attached to an attention module
+
+    :param model: parent model of attention module
+    :param module: attention module to initialize with
+    :param quantize: initialize kv cache quantization parameters
+    """
+    if not hasattr(module, KV_CACHE_ATTR):
+        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
+        module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
+
+    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
+    if quantize:
+        kv_cache.initialize_qparams_once(model, module)
+
+
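A hedged usage sketch for the initializer (the checkpoint name and the Llama-style module tree are illustrative assumptions, not part of this diff):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # hypothetical choice
    for layer in model.model.layers:  # assumes a Llama-style layout
        initialize_hooked_kv_cache(model, layer.self_attn, quantize=True)

Note that the `hasattr` guard makes the call idempotent: repeated calls reuse the existing `kv_cache` submodule, and `quantize=True` only creates qparams when the module already carries a `quantization_scheme` with input activation args.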
 # ----- hooks ----- #


-def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
+def register_key_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes post-rope key states as an argument and
+    returns the modified key states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: key hook function
+    """
     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

     def _hook(cache: QuantizedKVCache, args, kwargs):
@@ -128,7 +166,16 @@ def _hook(cache: QuantizedKVCache, args, kwargs):
     return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


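With the wrapper in place, a calibration-style key hook might be registered as follows (the running-max observer and the `_k_observed_max` attribute are stand-ins for illustration; only the hook signature comes from this diff):

    def observe_keys(module: Module, key_states: Tensor) -> Optional[Tensor]:
        # track a running max of post-rope key magnitudes; returning None leaves the states unchanged
        seen = getattr(module, "_k_observed_max", 0.0)
        module._k_observed_max = max(seen, key_states.abs().amax().item())
        return None

    handle = register_key_hook(attn_module, observe_keys)
    # ... run calibration forward passes ...
    handle.remove()  # detach once calibration is done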
-def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
+def register_value_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes value states as an argument and
+    returns the modified value states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: value hook function
+    """
     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

     def _hook(cache: QuantizedKVCache, args, kwargs):
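And symmetrically for values; here the hook rewrites the states rather than observing them (the clamp threshold is arbitrary, purely for illustration):

    def clip_values(module: Module, value_states: Tensor) -> Optional[Tensor]:
        # returning a tensor replaces the value states that the cache receives
        return value_states.clamp(-10.0, 10.0)

    handle = register_value_hook(attn_module, clip_values)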