
Commit e224a5d

add kv and attention
Signed-off-by: Kyle Sayers <[email protected]>
1 parent b5dc1e9 commit e224a5d

File tree: 9 files changed, +495 -33 lines changed
Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional
from weakref import ref

from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    forward_quantize,
)
from compressed_tensors.quantization.lifecycle.initialize import (
    _initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


__all__ = [
    "QuantizedAttentionImpl",
    "initialize_hooked_attention",
    "register_query_hook",
]


IMPL_ATTR = "impl"
HOOKED_ATTENTION_NAME = "ct_hooked_attention"


class QuantizedAttentionImpl(InternalModule):
    """
    QuantizedAttentionImpl module which wraps the functionality of the original
    attention implementation. Unlike the original attention function, this
    implementation is a `torch.nn.Module` which can be hooked to trigger
    transforms and calibration hooks.

    This module works by being registered as a submodule of attention modules via
    `initialize_hooked_attention`, registering a new attention implementation function
    which calls this module, then setting the model attention implementation to the new
    function. After triggering hooks and quantization, this module calls the original
    attention implementation function.

    :param config: model config
    :param attn_module: parent attention module
    """

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        self.attn_module = ref(attn_module)  # avoid circular references
        self._qparams_initialized = False

    def forward(
        self,
        module: Module,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *args,
        **kwargs,
    ):
        # quantization
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled and self._qparams_initialized:
            query = forward_quantize(module, query, "q", quant_args)

        # original attention
        return ALL_ATTENTION_FUNCTIONS[_original_impl](
            module,
            query,
            key,
            value,
            *args,
            **kwargs,
        )

    def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
        """
        Initialize attention quantization parameters if they have not already been
        initialized. KV cache quantization parameters are initialized by the
        `QuantizedKVCache`

        :param model: parent model of attention module
        :param module: attention module to initialize with
        """
        # TODO: move to initialize.py
        assert module is self.attn_module()
        scheme: Optional[QuantizationScheme] = getattr(
            module, "quantization_scheme", None
        )
        quant_args: Optional[QuantizationArgs] = getattr(
            scheme, "input_activations", None
        )

        if (
            not self._qparams_initialized
            and quant_args is not None
            and not scheme.kv_cache_only
        ):
            assert quant_args.strategy == QuantizationStrategy.TENSOR
            _initialize_scale_zero_point(module, "q", quant_args)
            self._qparams_initialized = True


# ----- initialize ----- #


def _ct_hooked_attention(module: Module, *args, **kwargs):
    if hasattr(module, IMPL_ATTR):
        return module.impl(module, *args, **kwargs)
    else:
        return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)


def initialize_hooked_attention(
    model: PreTrainedModel, module: Module, quantize: bool = True
):
    """
    Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
    attached to attention

    :param model: parent model of attention module
    :param module: attention module to initialize with
    :param quantize: initialize attention quantization parameters
    """
    if not hasattr(module, IMPL_ATTR):
        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config, module))
        if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
            # assumes only one model at a time
            global _original_impl
            _original_impl = model.config._attn_implementation

            AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention)
            model.config._attn_implementation = HOOKED_ATTENTION_NAME

    impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
    if quantize:
        impl.initialize_qparams_once(model, module)

    initialize_hooked_kv_cache(model, module, quantize=quantize)


# ----- hooks ----- #


def register_query_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope query states as an argument and
    returns the modified query states or `None`

    :param module: attention module to add hook to
    :param hook: query hook function
    """
    impl = getattr(module, IMPL_ATTR)

    def _hook(impl: QuantizedAttentionImpl, args, kwargs):
        bound = inspect.signature(impl.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["query"])
        if value is not None:
            bound.arguments["query"] = value

        return bound.args, bound.kwargs

    return impl.register_forward_pre_hook(_hook, with_kwargs=True)
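
For context, a minimal usage sketch of the attention hooks added above. This is not part of the commit: the model name, the loop over `model.model.layers`, the hook body, and the `compressed_tensors.modeling.attention` import path are illustrative assumptions.

from transformers import AutoModelForCausalLM

from compressed_tensors.modeling.attention import (  # assumed module path for this file
    initialize_hooked_attention,
    register_query_hook,
)

# hypothetical model; assumes decoder layers expose `self_attn` via `model.model.layers`
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")


def observe_query(module, query):
    # illustrative calibration hook: record the max absolute post-rope query value
    module._query_absmax = query.abs().max()
    return None  # returning None leaves the query states unchanged


for layer in model.model.layers:
    attn = layer.self_attn
    # wrap the attention call and attach a QuantizedKVCache; quantize=False skips
    # qparam initialization, so no quantization_scheme is required on the module
    initialize_hooked_attention(model, attn, quantize=False)
    register_query_hook(attn, observe_query)
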
Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional, Tuple
from weakref import ref

from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
from compressed_tensors.quantization.lifecycle.initialize import (
    _initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PretrainedConfig, PreTrainedModel


__all__ = [
    "QuantizedKVCache",
    "initialize_hooked_kv_cache",
    "register_key_hook",
    "register_value_hook",
]


KV_CACHE_ATTR = "kv_cache"


class QuantizedKVCache(InternalModule):
    """
    QuantizedKVCache module which wraps the functionality of any existing kv cache.
    Unlike transformers `Cache` instances, this cache is a `torch.nn.Module` which
    can be hooked to trigger transforms and calibration hooks.

    This module works by being registered as a submodule of attention modules via
    `initialize_hooked_kv_cache`, then adding a hook which replaces the
    `past_key_values` kwarg with this module. This module adopts the functionality
    of the replaced cache, preserving caching functionality such as sliding window
    attention, etc.

    :param config: model config
    :param attn_module: parent attention module
    """

    def __init__(self, config: PretrainedConfig, attn_module: Module):
        super().__init__()
        self.config = config
        self.attn_module = ref(attn_module)  # avoid circular reference
        self.past_key_values: Optional[Cache] = None
        self._qparams_initialized = False

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        # quantization
        module = self.attn_module()
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled and self._qparams_initialized:
            key_states = forward_quantize(module, key_states, "k", quant_args)
            value_states = forward_quantize(module, value_states, "v", quant_args)

        # original cache
        if self.past_key_values is not None:
            ret = self.past_key_values.update(key_states, value_states, *args, **kwargs)
        else:
            ret = (key_states, value_states)

        self.past_key_values = None
        return ret

    def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
        """
        Initialize kv cache quantization parameters if they have not already been
        initialized

        :param model: parent model of attention module
        :param module: attention module to initialize with
        """
        # TODO: move to initialize.py
        assert module is self.attn_module()
        scheme = getattr(module, "quantization_scheme", None)
        quant_args = getattr(scheme, "input_activations", None)

        if not self._qparams_initialized and quant_args is not None:
            assert quant_args.strategy == QuantizationStrategy.TENSOR
            _initialize_scale_zero_point(module, "k", quant_args)
            _initialize_scale_zero_point(module, "v", quant_args)
            self._qparams_initialized = True


# ----- initialize ----- #


def _kv_cache_attention_hook(module: Module, args, kwargs):
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    _past_kv_name = (
        "past_key_values"  # transformers#39956
        if "past_key_values" in inspect.signature(module.forward).parameters
        else "past_key_value"
    )
    kv_cache.past_key_values = kwargs.get(_past_kv_name, None)
    kwargs[_past_kv_name] = kv_cache

    return args, kwargs


def initialize_hooked_kv_cache(
    model: PreTrainedModel, module: Module, quantize: bool = False
):
    """
    Initialize a `QuantizedKVCache` instance attached to attention

    :param model: parent model of attention module
    :param module: attention module to initialize with
    :param quantize: initialize kv cache quantization parameters
    """
    if not hasattr(module, KV_CACHE_ATTR):
        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
        module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)

    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    if quantize:
        kv_cache.initialize_qparams_once(model, module)


# ----- hooks ----- #


def register_key_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes post-rope key states as an argument and
    returns the modified key states or `None`

    :param module: attention module to add hook to
    :param hook: key hook function
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["key_states"])
        if value is not None:
            bound.arguments["key_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


def register_value_hook(
    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
) -> RemovableHandle:
    """
    Register a hook which takes value states as an argument and
    returns the modified value states or `None`

    :param module: attention module to add hook to
    :param hook: value hook function
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["value_states"])
        if value is not None:
            bound.arguments["value_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
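
Similarly, a minimal sketch of how the key/value hooks might be attached for calibration. The model name and hook bodies are illustrative assumptions; note that `initialize_hooked_kv_cache` is also called for you by `initialize_hooked_attention` in the attention file above.

from transformers import AutoModelForCausalLM

from compressed_tensors.modeling.kvcache import (
    initialize_hooked_kv_cache,
    register_key_hook,
    register_value_hook,
)

# hypothetical model; assumes decoder layers expose `self_attn` and accept a
# `past_key_value(s)` kwarg, which the pre-hook above swaps for the QuantizedKVCache
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")


def observe_keys(module, key_states):
    # illustrative observer: record the key range, leave the states unchanged
    module._key_absmax = key_states.abs().max()
    return None


def scale_values(module, value_states):
    # illustrative transform: apply a (made-up) per-module scale before caching
    return value_states * getattr(module, "_v_scale", 1.0)


for layer in model.model.layers:
    attn = layer.self_attn
    initialize_hooked_kv_cache(model, attn, quantize=False)
    register_key_hook(attn, observe_keys)
    register_value_hook(attn, scale_values)
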
