From e224a5d2bbc53635e20df8fcd503293f0a5b2c4c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 7 Oct 2025 18:32:32 -0400 Subject: [PATCH] add kv and attention Signed-off-by: Kyle Sayers --- src/compressed_tensors/modeling/attention.py | 186 +++++++++++++++++ src/compressed_tensors/modeling/kvcache.py | 192 ++++++++++++++++++ .../transform/factory/base.py | 37 +++- .../transform/factory/hadamard.py | 1 - .../transform/factory/matrix_multiply.py | 1 - .../transform/transform_args.py | 15 +- .../transform/utils/matrix.py | 34 ++-- tests/test_transform/conftest.py | 22 +- .../factory/test_correctness.py | 40 +++- 9 files changed, 495 insertions(+), 33 deletions(-) create mode 100644 src/compressed_tensors/modeling/attention.py create mode 100644 src/compressed_tensors/modeling/kvcache.py diff --git a/src/compressed_tensors/modeling/attention.py b/src/compressed_tensors/modeling/attention.py new file mode 100644 index 000000000..fd7a2c777 --- /dev/null +++ b/src/compressed_tensors/modeling/attention.py @@ -0,0 +1,186 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Optional +from weakref import ref + +from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, + forward_quantize, +) +from compressed_tensors.quantization.lifecycle.initialize import ( + _initialize_scale_zero_point, +) +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle +from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + +__all__ = [ + "QuantizedAttentionImpl", + "initialize_hooked_attention", + "register_query_hook", +] + + +IMPL_ATTR = "impl" +HOOKED_ATTENTION_NAME = "ct_hooked_attention" + + +class QuantizedAttentionImpl(InternalModule): + """ + QuantizedAttentionImpl module which wraps the functionality of the original + attention implementation. Unlike the original attention function, this + implementation is a `torch.nn.Module` which can be hooked to trigger + transforms and calibration hooks. + + This module works by being registered as a submodule to attention modules via + `initialize_hooked_attention`, registering a new attention implementation function + which calls this module, then setting the model attention implementation to the new + function. After triggering hooks and quantization, this module calls the original + attention implementation function. 
+ + :param attn_module: parent attention module + """ + + def __init__(self, config: PretrainedConfig, attn_module: Module): + super().__init__() + self.config = config + self.attn_module = ref(attn_module) # avoid circular references + self._qparams_initialized = False + + def forward( + self, + module: Module, + query: Tensor, + key: Tensor, + value: Tensor, + *args, + **kwargs, + ): + # quantization + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled and self._qparams_initialized: + query = forward_quantize(module, query, "q", quant_args) + + # original attention + return ALL_ATTENTION_FUNCTIONS[_original_impl]( + module, + query, + key, + value, + *args, + **kwargs, + ) + + def initialize_qparams_once(self, model: PreTrainedModel, module: Module): + """ + Initialize attention quantization parameters if they have not already been + initialized. KV cache quantization parameters are initialized by the + `QuantizedKVCache` + + :param model: parent model of attention module + :param module: attention module to initialize with + """ + # TODO: move to initialize.py + assert module is self.attn_module() + scheme: Optional[QuantizationScheme] = getattr( + module, "quantization_scheme", None + ) + quant_args: Optional[QuantizationArgs] = getattr( + scheme, "input_activations", None + ) + + if ( + not self._qparams_initialized + and quant_args is not None + and not scheme.kv_cache_only + ): + assert quant_args.strategy == QuantizationStrategy.TENSOR + _initialize_scale_zero_point(module, "q", quant_args) + self._qparams_initialized = True + + +# ----- initialize ----- # + + +def _ct_hooked_attention(module: Module, *args, **kwargs): + if hasattr(module, IMPL_ATTR): + return module.impl(module, *args, **kwargs) + else: + return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) + + +def initialize_hooked_attention( + model: PreTrainedModel, module: Module, quantize: bool = True +): + """ + Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances + attached to attention + + :param model: parent model of attention module + :param module: attention module to initialize with + :param quantize: initialize attention quantization parameters + """ + if not hasattr(module, IMPL_ATTR): + module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config, module)) + if model.config._attn_implementation != HOOKED_ATTENTION_NAME: + # assumes only one model at a time + global _original_impl + _original_impl = model.config._attn_implementation + + AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention) + model.config._attn_implementation = HOOKED_ATTENTION_NAME + + impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR) + if quantize: + impl.initialize_qparams_once(model, module) + + initialize_hooked_kv_cache(model, module, quantize=quantize) + + +# ----- hooks ----- # + + +def register_query_hook( + module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] +) -> RemovableHandle: + """ + Register a hook which takes post-rope query states as an argument and + returns the modified query states or `None` + + :param module: attention module to add hook to + :param hook: query hook function + """ + impl = getattr(module, IMPL_ATTR) + + def _hook(impl: QuantizedAttentionImpl, args, kwargs): + bound = inspect.signature(impl.forward).bind(*args, **kwargs) + value = hook(module, 
bound.arguments["query"]) + if value is not None: + bound.arguments["query"] = value + + return bound.args, bound.kwargs + + return impl.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/modeling/kvcache.py b/src/compressed_tensors/modeling/kvcache.py new file mode 100644 index 000000000..a1f04882e --- /dev/null +++ b/src/compressed_tensors/modeling/kvcache.py @@ -0,0 +1,192 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Optional, Tuple +from weakref import ref + +from compressed_tensors.quantization import QuantizationStrategy, forward_quantize +from compressed_tensors.quantization.lifecycle.initialize import ( + _initialize_scale_zero_point, +) +from compressed_tensors.utils import getattr_chain +from compressed_tensors.utils.internal import InternalModule +from torch import Tensor +from torch.nn import Module +from torch.utils.hooks import RemovableHandle +from transformers import Cache, PretrainedConfig, PreTrainedModel + + +__all__ = [ + "QuantizedKVCache", + "initialize_hooked_kv_cache", + "register_key_hook", + "register_value_hook", +] + + +KV_CACHE_ATTR = "kv_cache" + + +class QuantizedKVCache(InternalModule): + """ + QuantizedKVCache module which wraps the functionality of any existing kvcache args. + Unlike transform Cache instances, this cache is a `torch.nn.Module` which can be + hooked to trigger transforms and calibration hooks. + + This module works by being registered as a submodule to attention modules via + `initialize_hooked_kv_cache`, then adding a hook which replaces `past_key_values` + kwargs with this module. This module adopts the functionality of the replaced cache, + preserving caching functionality such as sliding window attention, ect. 
+ + :param attn_module: parent attention module + """ + + def __init__(self, config: PretrainedConfig, attn_module: Module): + super().__init__() + self.config = config + self.attn_module = ref(attn_module) # avoid circular reference + self.past_key_values: Optional[Cache] = None + self._qparams_initialized = False + + def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]: + return self(*args, **kwargs) + + def forward( + self, + key_states: Tensor, + value_states: Tensor, + *args, + **kwargs, + ) -> Tuple[Tensor, Tensor]: + # quantization + module = self.attn_module() + quant_args_attr = "quantization_scheme.input_activations" + quant_args = getattr_chain(module, quant_args_attr, None) + quant_enabled = getattr(module, "quantization_enabled", True) + if quant_args is not None and quant_enabled and self._qparams_initialized: + key_states = forward_quantize(module, key_states, "k", quant_args) + value_states = forward_quantize(module, value_states, "v", quant_args) + + # original cache + if self.past_key_values is not None: + ret = self.past_key_values.update(key_states, value_states, *args, **kwargs) + else: + ret = (key_states, value_states) + + self.past_key_values = None + return ret + + def initialize_qparams_once(self, model: PreTrainedModel, module: Module): + """ + Initialize kv cache quantization parameters if they have not already been + initialized + + :param model: parent model of attention module + :param module: attention module to initialize with + """ + # TODO: move to initialize.py + assert module is self.attn_module() + scheme = getattr(module, "quantization_scheme", None) + quant_args = getattr(scheme, "input_activations", None) + + if not self._qparams_initialized and quant_args is not None: + assert quant_args.strategy == QuantizationStrategy.TENSOR + _initialize_scale_zero_point(module, "k", quant_args) + _initialize_scale_zero_point(module, "v", quant_args) + self._qparams_initialized = True + + +# ----- initialize ----- # + + +def _kv_cache_attention_hook(module: Module, args, kwargs): + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + _past_kv_name = ( + "past_key_values" # transformers#39956 + if "past_key_values" in inspect.signature(module.forward).parameters + else "past_key_value" + ) + kv_cache.past_key_values = kwargs.get(_past_kv_name, None) + kwargs[_past_kv_name] = kv_cache + + return args, kwargs + + +def initialize_hooked_kv_cache( + model: PreTrainedModel, module: Module, quantize: bool = False +): + """ + Initialize a `QuantizedKVCache` instance attached to attention + + :param model: parent model of attention module + :param module: attention module to initialize with + :param quantize: initialize kv cache quantization parameters + """ + if not hasattr(module, KV_CACHE_ATTR): + module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module)) + module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True) + + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + if quantize: + kv_cache.initialize_qparams_once(model, module) + + +# ----- hooks ----- # + + +def register_key_hook( + module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] +) -> RemovableHandle: + """ + Register a hook which takes post-rope key states as an argument and + returns the modified key states or `None` + + :param module: attention module to add hook to + :param hook: key hook function + """ + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = 
inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["key_states"]) + if value is not None: + bound.arguments["key_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) + + +def register_value_hook( + module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]] +) -> RemovableHandle: + """ + Register a hook which takes value states as an argument and + returns the modified value states or `None` + + :param module: attention module to add hook to + :param hook: value hook function + """ + kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR) + + def _hook(cache: QuantizedKVCache, args, kwargs): + bound = inspect.signature(cache.forward).bind(*args, **kwargs) + value = hook(module, bound.arguments["value_states"]) + if value is not None: + bound.arguments["value_states"] = value + + return bound.args, bound.kwargs + + return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 34d609e74..94dc2dac4 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -18,6 +18,14 @@ import torch import torch.nn.utils.parametrize as P import tqdm +from compressed_tensors.modeling.attention import ( + initialize_hooked_attention, + register_query_hook, +) +from compressed_tensors.modeling.kvcache import ( + initialize_hooked_kv_cache, + register_key_hook, +) from compressed_tensors.registry.registry import RegistryMixin, T from compressed_tensors.transform import ( TransformArgs, @@ -36,6 +44,7 @@ from compressed_tensors.utils.internal import InternalModule from torch import Tensor from torch.nn import Module, Parameter +from transformers import PreTrainedModel __all__ = ["TransformFactory", "TransformBase"] @@ -97,12 +106,13 @@ def apply_to_model(self, model: Module, use_tqdm=True): desc = f"Applying {self.name} transforms" for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)): - self._apply_to_module(module, arg) + self._apply_to_module(model, module, arg) - def _apply_to_module(self, module: Module, args: TransformArgs): + def _apply_to_module(self, model: Module, module: Module, args: TransformArgs): """ Create transforms and apply them to the module + :param model: model which module belongs to :param module: target module to apply transforms to :param args: defines how the transform will be applied to the target module """ @@ -156,7 +166,28 @@ def output_hook(_, _input, output): module.register_forward_hook(output_hook) - # other locations such as q_attn and k_attn have not been implemented + # register query hook to attention + elif args.location == TransformLocation.Q_ATTN: + if not isinstance(model, PreTrainedModel): + raise ValueError(f"Cannot hook attention of model: {model}") + + def query_hook(_, query_states): + return transform(query_states) + + initialize_hooked_attention(model, module, quantize=False) + register_query_hook(module, query_hook) + + # register key hook to kvcache + elif args.location == TransformLocation.K_CACHE: + if not isinstance(model, PreTrainedModel): + raise ValueError(f"Cannot hook attention of model: {model}") + + def key_hook(_, key_states): + return transform(key_states) + + initialize_hooked_kv_cache(model, module, quantize=False) + register_key_hook(module, key_hook) + else: raise NotImplementedError() diff --git 
a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index a843e2728..3b78dd25e 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -51,7 +51,6 @@ def create_transform(self, module: Module, args: TransformArgs): :param module: parent module that transform will be applied to :param args: defines how the transform will be applied to the module """ - assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) exec_device = get_execution_device(module) device = get_offloaded_device(module) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 847034d5c..6d3920f97 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -50,7 +50,6 @@ def create_transform(self, module: Module, args: TransformArgs): :param module: parent module that transform will be applied to :param args: defines how the transform will be applied to the module """ - assert hasattr(module, "weight") size = get_transform_size(module, args.location, self.scheme.head_dim) device = get_offloaded_device(module) precision = self.scheme.precision if args.is_online() else torch.float64 diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index d3f469579..3967d4616 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -45,6 +45,16 @@ class TransformLocation(str, Enum): K_CACHE = "k_cache" Q_ATTN = "q_attn" + def is_online(self) -> bool: + """ + Returns True if the transform location is online + (applied at runtime), False otherwise + """ + return self not in ( + TransformLocation.WEIGHT_INPUT, + TransformLocation.WEIGHT_OUTPUT, + ) + class TransformArgs(BaseModel, use_enum_values=True): """ @@ -70,9 +80,6 @@ def wrap_singleton(cls, value): return value def is_online(self) -> bool: - return self.location not in ( - TransformLocation.WEIGHT_INPUT, - TransformLocation.WEIGHT_OUTPUT, - ) + return TransformLocation(self.location).is_online() model_config = ConfigDict(extra="forbid") diff --git a/src/compressed_tensors/transform/utils/matrix.py b/src/compressed_tensors/transform/utils/matrix.py index 920728571..0414e3f69 100644 --- a/src/compressed_tensors/transform/utils/matrix.py +++ b/src/compressed_tensors/transform/utils/matrix.py @@ -34,6 +34,8 @@ def get_transform_size( :param head_dim: size of head when transform is applied to mha :return: size of matrix """ + size = None + if isinstance(module, torch.nn.Linear): if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT): size = module.in_features @@ -44,11 +46,13 @@ def get_transform_size( size = module.num_embeddings else: size = module.embedding_dim - else: - raise NotImplementedError(f"Transforms on {type(module)} are not supported") + elif head_dim is None: + raise NotImplementedError( + f"Transforms on {type(module)} are not supported without head_dim" + ) if head_dim is not None: - if size % head_dim != 0: + if size is not None and size % head_dim != 0: raise ValueError( f"{head_dim} must divide {size} for {type(module)} at {location}" ) @@ -105,11 +109,11 @@ def apply_transform_weight( assert transform_weight.shape[0] == transform_weight.shape[1] - if module_type == torch.nn.Linear: - if 
location == TransformLocation.INPUT: - return _multihead_matmul(value, transform_weight) + if TransformLocation(location).is_online(): + return _multihead_matmul(value, transform_weight) - elif location == TransformLocation.WEIGHT_INPUT: + if module_type == torch.nn.Linear: + if location == TransformLocation.WEIGHT_INPUT: # equivalent to (transform_weight @ value.T).T return _multihead_matmul(value, transform_weight.T) @@ -117,26 +121,14 @@ def apply_transform_weight( # equivalent to (value.T @ transform_weight).T return _multihead_matmul(transform_weight.T, value) - elif location == TransformLocation.OUTPUT: - return _multihead_matmul(value, transform_weight) - # similar derivation to torch.nn.Linear, but `y = (x W)` elif module_type == torch.nn.Embedding: - if location == TransformLocation.INPUT: - return _multihead_matmul(value, transform_weight) - - elif location == TransformLocation.WEIGHT_INPUT: - return _multihead_matmul( - transform_weight, - value, - ) + if location == TransformLocation.WEIGHT_INPUT: + return _multihead_matmul(transform_weight, value) elif location == TransformLocation.WEIGHT_OUTPUT: return _multihead_matmul(value, transform_weight) - elif location == TransformLocation.OUTPUT: - return _multihead_matmul(value, transform_weight) - raise NotImplementedError( f"Applying transforms to {module_type} {location} is not supported" ) diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py index 824c06bd3..0ab5093c6 100644 --- a/tests/test_transform/conftest.py +++ b/tests/test_transform/conftest.py @@ -62,7 +62,9 @@ def __init__( num_attention_heads * self.head_dim, hidden_size, bias=False ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, past_key_values=None + ) -> torch.Tensor: batch_size, seq_len, hidden_size = hidden_states.shape hidden_shape = (batch_size, seq_len, -1, self.head_dim) @@ -70,6 +72,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if past_key_values is not None: + past_key_values.update(key_states, value_states, 0, {}) + key_states = self.repeat_kv(key_states, self.num_key_value_groups) value_states = self.repeat_kv(value_states, self.num_key_value_groups) @@ -97,6 +102,21 @@ def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +class MockAttentionModel(PreTrainedModel): + config_class = PretrainedConfig + + def __init__(self, hidden_size, num_attention_heads, num_key_value_heads): + super().__init__(PretrainedConfig()) + self.self_attn = MockAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + ) + + def forward(self, x): + return self.self_attn(x) + + @pytest.fixture(scope="function") def model_apply(): model = TransformableModel(2, 4, 8, 16, 32, 64) diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py index e95ba8de5..9598f6dc2 100644 --- a/tests/test_transform/factory/test_correctness.py +++ b/tests/test_transform/factory/test_correctness.py @@ -22,7 +22,7 @@ apply_transform_config, ) from compressed_tensors.utils import offloaded_dispatch -from tests.test_transform.conftest import MockAttention +from tests.test_transform.conftest import MockAttention, 
MockAttentionModel from tests.testing_utils import requires_accelerate, requires_gpu @@ -147,7 +147,7 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size config = TransformConfig( config_groups={ - "": TransformScheme( + "R2": TransformScheme( type=type, randomize=randomize, head_dim=head_dim, @@ -164,3 +164,39 @@ def test_correctness_attention_heads(type, randomize, head_dim, input_batch_size output = attention(input) assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0) + + +@pytest.mark.parametrize("type", ("hadamard", "random-hadamard")) +@pytest.mark.parametrize("randomize", (True, False)) +@pytest.mark.parametrize("head_dim", (4, 8)) +@pytest.mark.parametrize("input_batch_size", (1, 5, 17)) +def test_correctness_query_key_locations(type, randomize, head_dim, input_batch_size): + hidden_size = 64 + num_attention_heads = 8 + + model = MockAttentionModel( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=head_dim, + ) + + input = torch.rand(input_batch_size, 5, hidden_size) + true_output = model(input) + + config = TransformConfig( + config_groups={ + "R3": TransformScheme( + type=type, + randomize=randomize, + head_dim=head_dim, + apply=[ + TransformArgs(targets="self_attn", location="q_attn"), + TransformArgs(targets="self_attn", location="k_cache"), + ], + ) + } + ) + apply_transform_config(model, config) + + output = model(input) + assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
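
A minimal usage sketch for the new attention hooks (not part of the diff above). It assumes a Llama-style checkpoint whose decoder layers expose a `self_attn` submodule; the model id is a placeholder:

    from transformers import AutoModelForCausalLM

    from compressed_tensors.modeling.attention import (
        initialize_hooked_attention,
        register_query_hook,
    )

    # placeholder model id; any HF causal LM with `model.model.layers[*].self_attn`
    # should work the same way
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

    def scale_query(attn_module, query_states):
        # returning a tensor replaces the post-rope query states;
        # returning None leaves them untouched
        return query_states * 0.5

    for layer in model.model.layers:
        attn = layer.self_attn
        # registers the `impl` and `kv_cache` submodules and points the model at the
        # "ct_hooked_attention" implementation; quantize=False skips qparam init
        initialize_hooked_attention(model, attn, quantize=False)
        register_query_hook(attn, scale_query)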
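
After hooking, the wiring can be sanity-checked against the attribute names the patch registers (`impl`, `kv_cache`, and the `ct_hooked_attention` implementation):

    attn = model.model.layers[0].self_attn
    assert model.config._attn_implementation == "ct_hooked_attention"
    assert hasattr(attn, "impl")      # QuantizedAttentionImpl submodule
    assert hasattr(attn, "kv_cache")  # QuantizedKVCache submodule, added by the same call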
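
`initialize_hooked_attention` also installs the hooked KV cache, so key and value hooks can be registered on the same modules. A sketch using the hook signatures from the patch; the norm logging and value centering are illustrative only:

    from compressed_tensors.modeling.kvcache import (
        initialize_hooked_kv_cache,
        register_key_hook,
        register_value_hook,
    )

    def log_key_norm(attn_module, key_states):
        # returning None keeps the original key states
        print(f"key norm: {key_states.norm().item():.4f}")
        return None

    def center_values(attn_module, value_states):
        # returning a tensor replaces the value states before caching/quantization
        return value_states - value_states.mean(dim=-1, keepdim=True)

    for layer in model.model.layers:
        attn = layer.self_attn
        initialize_hooked_kv_cache(model, attn, quantize=False)  # no-op if already hooked
        register_key_hook(attn, log_key_norm)
        register_value_hook(attn, center_values)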
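
The `is_online` helper moves from `TransformArgs` onto `TransformLocation`, so it can also be queried on bare locations. A quick illustration, assuming both names are exported from `compressed_tensors.transform` as they are elsewhere in this patch:

    from compressed_tensors.transform import TransformArgs, TransformLocation

    # weight locations are folded into the weights offline; every other location,
    # including the new q_attn / k_cache hooks, is applied at inference time
    assert not TransformArgs(
        targets="q_proj", location=TransformLocation.WEIGHT_INPUT
    ).is_online()
    assert TransformArgs(targets="self_attn", location="q_attn").is_online()
    assert TransformLocation.K_CACHE.is_online()
    assert TransformLocation.OUTPUT.is_online()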
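
End to end, the new locations are exercised through a transform config, mirroring the added test. The regex-style targets and the `head_dim` computation are assumptions for a real checkpoint (module names are fully qualified, and `head_dim` should match the model's attention head size); `apply_transform_config` installs the attention and KV-cache hooks itself via the new branches in `TransformFactory._apply_to_module`, so no manual hooking is needed here:

    from compressed_tensors.transform import (
        TransformArgs,
        TransformConfig,
        TransformScheme,
        apply_transform_config,
    )

    # assumed head size; adjust if the config exposes `head_dim` directly
    head_dim = model.config.hidden_size // model.config.num_attention_heads

    config = TransformConfig(
        config_groups={
            "R3": TransformScheme(
                type="hadamard",
                head_dim=head_dim,
                apply=[
                    # post-rope queries and pre-cache keys receive the same rotation
                    TransformArgs(targets="re:.*self_attn$", location="q_attn"),
                    TransformArgs(targets="re:.*self_attn$", location="k_cache"),
                ],
            )
        }
    )
    apply_transform_config(model, config)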