
Commit 7bf4b57

docstrings
Signed-off-by: Kyle Sayers <[email protected]>
1 parent b6636c8 commit 7bf4b57

File tree

3 files changed: +130 −39 lines changed

src/compressed_tensors/modeling/attention.py

Lines changed: 56 additions & 16 deletions
@@ -15,7 +15,6 @@
 import inspect
 from typing import Callable, Optional
 
-import torch
 from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
 from compressed_tensors.quantization import (
     QuantizationArgs,
@@ -28,30 +27,51 @@
 )
 from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.internal import InternalModule
+from torch import Tensor
+from torch.nn import Module
 from torch.utils.hooks import RemovableHandle
 from transformers import AttentionInterface, PreTrainedModel
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
 
-__all__ = ["IMPL_ATTR", "QuantizedAttentionImpl"]
+__all__ = [
+    "QuantizedAttentionImpl",
+    "initialize_hooked_attention",
+    "register_query_hook",
+]
 
 
 IMPL_ATTR = "impl"
-_original_impl = "eager"  # mutable
+_original_impl = "eager"  # mutable, assumes only one model at a time
 
 
 class QuantizedAttentionImpl(InternalModule):
-    def __init__(self, attn_module: torch.nn.Module):
+    """
+    QuantizedAttentionImpl module which wraps the functionality of the original
+    attention implementation. Unlike the original attention function, this
+    implementation is a `torch.nn.Module` which can be hooked to trigger
+    transforms and calibration hooks.
+
+    This module works by being registered as a submodule to attention modules via
+    `initialize_hooked_attention`, registering a new attention implementation function
+    which calls this module, then setting the model attention implementation to the new
+    function. After triggering hooks and quantization, this module calls the original
+    attention implementation function.
+
+    :param attn_module: parent attention module
+    """
+
+    def __init__(self, attn_module: Module):
         super().__init__()
         self.attn_module_container = [attn_module]  # avoid circular reference
         self._qparams_initialized = False
 
     def forward(
         self,
-        module: torch.nn.Module,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
+        module: Module,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
         *args,
         **kwargs,
     ):
@@ -72,7 +92,15 @@ def forward(
             **kwargs,
         )
 
-    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
+    def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
+        """
+        Initialize attention quantization parameters if they have not already been
+        intialized. KV cache quantization parameters are initialized by the
+        `QuantizedKVCache`
+
+        :param model: parent model of attention module
+        :param module: attention module to initialize with
+        """
         assert module is self.attn_module_container[0]
         scheme: Optional[QuantizationScheme] = getattr(
             module, "quantization_scheme", None
@@ -86,7 +114,6 @@ def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
             and quant_args is not None
             and not scheme.kv_cache_only
         ):
-            # TODO: use model.config.num_attention_heads to find query_size
             assert quant_args.strategy == QuantizationStrategy.TENSOR
             _initialize_scale_zero_point(module, "q", quant_args)
             self._qparams_initialized = True
@@ -95,24 +122,32 @@ def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
 # ----- initialize ----- #
 
 
-def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs):
+def _ct_hooked_attention(module: Module, *args, **kwargs):
     if hasattr(module, IMPL_ATTR):
         return module.impl(module, *args, **kwargs)
     else:
         return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)
 
 
 def initialize_hooked_attention(
-    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = True
+    model: PreTrainedModel, module: Module, quantize: bool = True
 ):
+    """
+    Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
+    attached to attention
+
+    :param model: parent model of attention module
+    :param module: attention module to initialize with
+    :param quantize: initialize attention quantization parameters
+    """
     if not hasattr(module, IMPL_ATTR):
         module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module))
     if model.config._attn_implementation != "ct_hooked_attention":
         # assumes only one model at a time
         global _original_impl
         _original_impl = model.config._attn_implementation
 
-        AttentionInterface.register("ct_hooked_attention", ct_hooked_attention)
+        AttentionInterface.register("ct_hooked_attention", _ct_hooked_attention)
        model.config._attn_implementation = "ct_hooked_attention"
 
     impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
@@ -125,10 +160,15 @@ def initialize_hooked_attention(
 # ----- hooks ----- #
 
 
-def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
+def register_query_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
     """
-    Registers a forward pre-hook on `module.impl` that replaces the `query` argument
-    with `hook(mod, query)` (handles both positional and keyword forms).
+    Register a hook which takes post-rope query states as an argument and
+    returns the modified query states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: query hook function
     """
     impl = getattr(module, IMPL_ATTR)
 

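For orientation, the sketch below shows how the attention entry points documented in this diff might be wired up during calibration. It is not part of the commit: the checkpoint name, the `self_attn` module-name suffix, and the `log_query_absmax` observer are assumptions made purely for illustration.

# Hypothetical usage sketch: hook every attention module and observe post-rope query states.
from torch import Tensor
from torch.nn import Module
from transformers import AutoModelForCausalLM

from compressed_tensors.modeling.attention import (
    initialize_hooked_attention,
    register_query_hook,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed checkpoint


def log_query_absmax(module: Module, query: Tensor) -> None:
    # Observer-style hook: returning None leaves the query states unchanged
    print(f"{module.__class__.__name__}: query absmax = {query.abs().max().item():.4f}")


for name, submodule in model.named_modules():
    if name.endswith("self_attn"):  # assumed attention-module naming convention
        initialize_hooked_attention(model, submodule, quantize=False)
        register_query_hook(submodule, log_query_absmax)

Returning a `Tensor` from the hook instead of `None` would replace the query states, which is how a runtime transform would plug in.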
src/compressed_tensors/modeling/kvcache.py

Lines changed: 70 additions & 23 deletions
@@ -15,30 +15,46 @@
 import inspect
 from typing import Callable, Optional, Tuple
 
-import torch
-import transformers
 from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
 from compressed_tensors.quantization.lifecycle.initialize import (
     _initialize_scale_zero_point,
 )
 from compressed_tensors.utils import getattr_chain
 from compressed_tensors.utils.internal import InternalModule
-from packaging import version
 from torch import Tensor
+from torch.nn import Module
 from torch.utils.hooks import RemovableHandle
 from transformers import Cache, PreTrainedModel
 
 
-__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"]
+__all__ = [
+    "QuantizedKVCache",
+    "initialize_hooked_kv_cache",
+    "register_key_hook",
+    "register_value_hook",
+]
 
 
 KV_CACHE_ATTR = "kv_cache"
 
 
 class QuantizedKVCache(InternalModule):
-    def __init__(self, attn_module: torch.nn.Module):
+    """
+    QuantizedKVCache module which wraps the functionality of any existing kvcache args.
+    Unlike transform Cache instances, this cache is a `torch.nn.Module` which can be
+    hooked to trigger transforms and calibration hooks.
+
+    This module works by being registered as a submodule to attention modules via
+    `initialize_hooked_kv_cache`, then adding a hook which replaces `past_key_values`
+    kwargs with this module. This module adopts the functionality of the replaced cache,
+    preserving caching functionality such as sliding window attention, ect.
+
+    :param attn_module: parent attention module
+    """
+
+    def __init__(self, attn_module: Module):
         super().__init__()
-        self.attn_module_container = [attn_module]  # avoid nn.Module circular reference
+        self.attn_module_container = [attn_module]  # avoid circular reference
         self.past_key_values: Optional[Cache] = None
         self._qparams_initialized = False
 
@@ -70,13 +86,19 @@ def forward(
         self.past_key_values = None
         return ret
 
-    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
+    def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
+        """
+        Initialize kv cache quantization parameters if they have not already been
+        intialized
+
+        :param model: parent model of attention module
+        :param module: attention module to initialize with
+        """
         assert module is self.attn_module_container[0]
         scheme = getattr(module, "quantization_scheme", None)
         quant_args = getattr(scheme, "input_activations", None)
 
         if not self._qparams_initialized and quant_args is not None:
-            # TODO: use model.config.num_key_value_heads to find key_size, value_size
            assert quant_args.strategy == QuantizationStrategy.TENSOR
            _initialize_scale_zero_point(module, "k", quant_args)
            _initialize_scale_zero_point(module, "v", quant_args)
@@ -86,19 +108,7 @@ def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
 # ----- initialize ----- #
 
 
-def initialize_hooked_kv_cache(
-    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False
-):
-    if not hasattr(module, KV_CACHE_ATTR):
-        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
-        module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True)
-
-    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
-    if quantize:
-        kv_cache.initialize_qparams_once(model, module)
-
-
-def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
+def _kv_cache_attention_hook(module: Module, args, kwargs):
     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
     _past_kv_name = (
         "past_key_values"  # transformers#39956
@@ -111,10 +121,38 @@ def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
     return args, kwargs
 
 
+def initialize_hooked_kv_cache(
+    model: PreTrainedModel, module: Module, quantize: bool = False
+):
+    """
+    Initialize a `QuantizedKVCache` instance attached to attention
+
+    :param model: parent model of attention module
+    :param module: attention module to initialize with
+    :param quantize: initialize kv cache quantization parameters
+    """
+    if not hasattr(module, KV_CACHE_ATTR):
+        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
+        module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
+
+    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
+    if quantize:
+        kv_cache.initialize_qparams_once(model, module)
+
+
 # ----- hooks ----- #
 
 
-def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
+def register_key_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes post-rope key states as an argument and
+    returns the modified key states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: key hook function
+    """
     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
 
     def _hook(cache: QuantizedKVCache, args, kwargs):
@@ -128,7 +166,16 @@ def _hook(cache: QuantizedKVCache, args, kwargs):
     return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
 
 
-def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
+def register_value_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes value states as an argument and
+    returns the modified value states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: value hook function
+    """
     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
 
     def _hook(cache: QuantizedKVCache, args, kwargs):

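Along the same lines, here is a hedged sketch of how the kv cache helpers added above might be used. It is not part of the commit: the checkpoint name, the `self_attn` suffix, the clamp threshold, and the hook names are assumptions for illustration.

# Hypothetical usage sketch: attach a QuantizedKVCache to each attention module and
# modify key/value states before they are quantized and cached.
from torch import Tensor
from torch.nn import Module
from transformers import AutoModelForCausalLM

from compressed_tensors.modeling.kvcache import (
    initialize_hooked_kv_cache,
    register_key_hook,
    register_value_hook,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed checkpoint


def clip_keys(module: Module, key: Tensor) -> Tensor:
    # Returning a Tensor replaces the post-rope key states (arbitrary clamp for the example)
    return key.clamp(min=-10.0, max=10.0)


def observe_values(module: Module, value: Tensor) -> None:
    # Returning None leaves the value states unchanged (observer-style hook)
    print("value absmax:", value.abs().max().item())


for name, submodule in model.named_modules():
    if name.endswith("self_attn"):  # assumed attention-module naming convention
        initialize_hooked_kv_cache(model, submodule, quantize=False)
        register_key_hook(submodule, clip_keys)
        register_value_hook(submodule, observe_values)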
src/compressed_tensors/transform/transform_args.py

Lines changed: 4 additions & 0 deletions
@@ -46,6 +46,10 @@ class TransformLocation(str, Enum):
     Q_ATTN = "q_attn"
 
     def is_online(self) -> bool:
+        """
+        Returns True if the transform location is online
+        (applied at runtime), False otherwise
+        """
         return self not in (
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,

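To illustrate the distinction the new `is_online` docstring draws, a small check follows (assuming `TransformLocation` is importable from `compressed_tensors.transform`): weight locations are folded into the checkpoint offline, while activation-side locations such as `Q_ATTN` are applied at runtime.

from compressed_tensors.transform import TransformLocation  # assumed import path

# Weight locations are merged into the weights ahead of time, so they are not "online"
assert not TransformLocation.WEIGHT_INPUT.is_online()
assert not TransformLocation.WEIGHT_OUTPUT.is_online()

# Locations applied to runtime activations (e.g. attention queries) report as online
print(TransformLocation.Q_ATTN.is_online())  # expected: True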