
Commit 4b79cef

attention and kvcache transforms

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 0914f6f commit 4b79cef

9 files changed, +396 -34 lines changed
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional

import torch
from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    forward_quantize,
)
from compressed_tensors.quantization.lifecycle.initialize import (
    _initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


__all__ = ["IMPL_ATTR", "QuantizedAttentionImpl"]


IMPL_ATTR = "impl"
_original_impl = "eager"  # mutable


class QuantizedAttentionImpl(InternalModule):
    """
    Wraps an attention module's implementation: optionally quantizes the query
    states, then dispatches to the original attention function. Attached to
    attention modules under the `impl` attribute by `initialize_hooked_attention`.
    """

    def __init__(self, attn_module: torch.nn.Module):
        super().__init__()
        self.attn_module_container = [attn_module]  # avoid circular reference
        self._qparams_initialized = False

    def forward(
        self,
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        *args,
        **kwargs,
    ):
        # quantization
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled and self._qparams_initialized:
            query = forward_quantize(module, query, "q", quant_args)

        # original attention
        return ALL_ATTENTION_FUNCTIONS[_original_impl](
            module,
            query,
            key,
            value,
            *args,
            **kwargs,
        )

    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
        """Initialize query scale and zero point, at most once per module"""
        assert module is self.attn_module_container[0]
        scheme: Optional[QuantizationScheme] = getattr(
            module, "quantization_scheme", None
        )
        quant_args: Optional[QuantizationArgs] = getattr(
            scheme, "input_activations", None
        )

        if (
            not self._qparams_initialized
            and quant_args is not None
            and not scheme.kv_cache_only
        ):
            # TODO: use model.config.num_attention_heads to find query_size
            assert quant_args.strategy == QuantizationStrategy.TENSOR
            _initialize_scale_zero_point(module, "q", quant_args)
            self._qparams_initialized = True


# ----- initialize ----- #


def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs):
    # dispatch to the hooked impl if present, otherwise fall back to the
    # original attention implementation
    if hasattr(module, IMPL_ATTR):
        return module.impl(module, *args, **kwargs)
    else:
        return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)


def initialize_hooked_attention(
    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = True
):
    """
    Attach a `QuantizedAttentionImpl` (and a hooked kv cache) to the given
    attention module and route the model through `ct_hooked_attention`
    """
    if not hasattr(module, IMPL_ATTR):
        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module))
        if model.config._attn_implementation != "ct_hooked_attention":
            # assumes only one model at a time
            global _original_impl
            _original_impl = model.config._attn_implementation

            AttentionInterface.register("ct_hooked_attention", ct_hooked_attention)
            model.config._attn_implementation = "ct_hooked_attention"

    impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
    if quantize:
        impl.initialize_qparams_once(model, module)

    initialize_hooked_kv_cache(model, module, quantize=quantize)


# ----- hooks ----- #


def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
    """
    Registers a forward pre-hook on `module.impl` that replaces the `query` argument
    with `hook(module, query)` (handles both positional and keyword forms).
    """
    impl = getattr(module, IMPL_ATTR)

    def _hook(impl: QuantizedAttentionImpl, args, kwargs):
        bound = inspect.signature(impl.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["query"])
        if value is not None:
            bound.arguments["query"] = value

        return bound.args, bound.kwargs

    return impl.register_forward_pre_hook(_hook, with_kwargs=True)
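
For orientation, a minimal usage sketch for this new module (imported as `compressed_tensors.modeling.attention` by `base.py` below). The checkpoint name, the `model.model.layers[*].self_attn` layout, and the query hook body are illustrative assumptions, not part of this commit:

from transformers import AutoModelForCausalLM
from compressed_tensors.modeling.attention import (
    initialize_hooked_attention,
    register_query_hook,
)

# assumed Llama-style checkpoint; any model whose layers expose `self_attn` works
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

for layer in model.model.layers:
    attn = layer.self_attn

    # routes attention through "ct_hooked_attention" and attaches a
    # QuantizedAttentionImpl under `attn.impl` (quantize=False skips qparam init)
    initialize_hooked_attention(model, attn, quantize=False)

    # hook receives (attn_module, query_states); returning a tensor replaces them
    def halve_query(_module, query_states):
        return query_states * 0.5

    register_query_hook(attn, halve_query)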
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional, Tuple

import torch
import transformers
from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
from compressed_tensors.quantization.lifecycle.initialize import (
    _initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from packaging import version
from torch import Tensor
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PreTrainedModel


__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"]


KV_CACHE_ATTR = "kv_cache"


class QuantizedKVCache(InternalModule):
    """
    Cache wrapper which optionally quantizes key and value states before
    delegating to the original `past_key_values` cache. Attached to attention
    modules under the `kv_cache` attribute by `initialize_hooked_kv_cache`.
    """

    def __init__(self, attn_module: torch.nn.Module):
        super().__init__()
        self.attn_module_container = [attn_module]  # avoid nn.Module circular reference
        self.past_key_values: Optional[Cache] = None
        self._qparams_initialized = False

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        # quantization
        module = self.attn_module_container[0]
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled and self._qparams_initialized:
            key_states = forward_quantize(module, key_states, "k", quant_args)
            value_states = forward_quantize(module, value_states, "v", quant_args)

        # original cache
        if self.past_key_values is not None:
            ret = self.past_key_values.update(key_states, value_states, *args, **kwargs)
        else:
            ret = (key_states, value_states)

        self.past_key_values = None
        return ret

    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
        """Initialize key and value scales and zero points, at most once per module"""
        assert module is self.attn_module_container[0]
        scheme = getattr(module, "quantization_scheme", None)
        quant_args = getattr(scheme, "input_activations", None)

        if not self._qparams_initialized and quant_args is not None:
            # TODO: use model.config.num_key_value_heads to find key_size, value_size
            assert quant_args.strategy == QuantizationStrategy.TENSOR
            _initialize_scale_zero_point(module, "k", quant_args)
            _initialize_scale_zero_point(module, "v", quant_args)
            self._qparams_initialized = True


# ----- initialize ----- #


def initialize_hooked_kv_cache(
    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False
):
    """
    Attach a `QuantizedKVCache` to the given attention module and register a
    pre-hook which substitutes it for the incoming `past_key_values`
    """
    if not hasattr(module, KV_CACHE_ATTR):
        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
        module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True)

    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    if quantize:
        kv_cache.initialize_qparams_once(model, module)


def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
    # capture the cache passed to the attention module and swap in the
    # QuantizedKVCache so that `Cache.update` flows through its forward
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    _past_kv_name = (
        "past_key_values"  # transformers#39956
        if "past_key_values" in inspect.signature(module.forward).parameters
        else "past_key_value"
    )
    kv_cache.past_key_values = kwargs.get(_past_kv_name, None)
    kwargs[_past_kv_name] = kv_cache

    return args, kwargs


# ----- hooks ----- #


def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
    """
    Registers a forward pre-hook on `module.kv_cache` that replaces the
    `key_states` argument with `hook(module, key_states)`.
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["key_states"])
        if value is not None:
            bound.arguments["key_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
    """
    Registers a forward pre-hook on `module.kv_cache` that replaces the
    `value_states` argument with `hook(module, value_states)`.
    """
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["value_states"])
        if value is not None:
            bound.arguments["value_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
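
A matching sketch for the kv cache side. The checkpoint, module path, and rotation matrix are assumptions used only to show the hook contract:

import torch
from transformers import AutoModelForCausalLM
from compressed_tensors.modeling.kvcache import (
    initialize_hooked_kv_cache,
    register_key_hook,
    register_value_hook,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed
attn = model.model.layers[0].self_attn  # assumed Llama-style layout

# attaches a QuantizedKVCache under `attn.kv_cache` and a pre-hook that swaps it
# in for the incoming `past_key_values` argument on every attention forward
initialize_hooked_kv_cache(model, attn, quantize=False)

# hypothetical rotation applied to keys/values before they reach the cache
head_dim = model.config.hidden_size // model.config.num_attention_heads
rotation = torch.eye(head_dim)

def rotate_keys(_module, key_states):
    return key_states @ rotation

def rotate_values(_module, value_states):
    return value_states @ rotation

register_key_hook(attn, rotate_keys)
register_value_hook(attn, rotate_values)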

src/compressed_tensors/transform/factory/base.py

Lines changed: 32 additions & 4 deletions
@@ -19,6 +19,14 @@
 import torch
 import torch.nn.utils.parametrize as P
 import tqdm
+from compressed_tensors.modeling.attention import (
+    initialize_hooked_attention,
+    register_query_hook,
+)
+from compressed_tensors.modeling.kvcache import (
+    initialize_hooked_kv_cache,
+    register_key_hook,
+)
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -37,6 +45,7 @@
 from compressed_tensors.utils.internal import InternalModule
 from torch import Tensor
 from torch.nn import Module, Parameter
+from transformers import PreTrainedModel


 __all__ = ["TransformFactory", "TransformBase"]
@@ -85,7 +94,7 @@ def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase":
         """
         raise NotImplementedError()

-    def apply_to_model(self, model: Module, use_tqdm=True):
+    def apply_to_model(self, model: PreTrainedModel, use_tqdm=True):
         """
         Create transforms and apply them to the model
@@ -99,11 +108,13 @@ def apply_to_model(self, model: Module, use_tqdm=True):

         desc = f"Applying {self.name} transforms"
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
-            self._apply_to_module(module, arg)
+            self._apply_to_module(module, arg, model)

         self._update_tied_weights()

-    def _apply_to_module(self, module: Module, args: TransformArgs):
+    def _apply_to_module(
+        self, module: Module, args: TransformArgs, model: PreTrainedModel
+    ):
         """
         Create transforms and apply them to the module
@@ -161,7 +172,24 @@ def output_hook(_, _input, output):

             module.register_forward_hook(output_hook)

-        # other locations such as q_attn and k_attn have not been implemented
+        # register query hook to attention
+        elif args.location == TransformLocation.Q_ATTN:
+            initialize_hooked_attention(model, module, quantize=False)
+
+            def query_hook(_, query_states):
+                return transform(query_states)
+
+            register_query_hook(module, query_hook)
+
+        # register key hook to kvcache
+        elif args.location == TransformLocation.K_CACHE:
+            initialize_hooked_kv_cache(model, module, quantize=False)
+
+            def key_hook(_, key_states):
+                return transform(key_states)
+
+            register_key_hook(module, key_hook)
+
         else:
             raise NotImplementedError()
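
To tie the pieces together, a rough end-to-end sketch of driving the new locations through a transform scheme. `TransformScheme`, `TransformFactory.from_scheme`, the `head_dim` field, and the `re:` target syntax are assumed from the existing library API rather than shown in this diff; only the `Q_ATTN`/`K_CACHE` branches above are new:

from transformers import AutoModelForCausalLM
from compressed_tensors.transform import (
    TransformArgs,
    TransformFactory,
    TransformLocation,
    TransformScheme,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed

# rotate query states and cached key states of every self-attention module
scheme = TransformScheme(
    type="hadamard",
    head_dim=128,  # assumed head size for the checkpoint above
    apply=[
        TransformArgs(targets=["re:.*self_attn$"], location=TransformLocation.Q_ATTN),
        TransformArgs(targets=["re:.*self_attn$"], location=TransformLocation.K_CACHE),
    ],
)

factory = TransformFactory.from_scheme(scheme, name="r3")
# per this diff, `model` must now be a PreTrainedModel, since hooking attention
# needs access to model.config._attn_implementation
factory.apply_to_model(model)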

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         exec_device = get_execution_device(module)
         device = get_offloaded_device(module)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64
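
Both factories drop the `assert hasattr(module, "weight")` guard, presumably because transforms can now target attention modules through the `Q_ATTN` and `K_CACHE` hooks above, where the transformed tensors are activations rather than a `weight` parameter.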
