46 commits
5a0ec31
quantization done
kylesayrs Aug 20, 2025
e7b1338
fix kv cache passing
kylesayrs Aug 20, 2025
36d7f2c
slightly cleaner, validated with r3
kylesayrs Aug 20, 2025
e64dbd2
qparam initialization
kylesayrs Aug 20, 2025
1a01dc3
add markers
kylesayrs Aug 20, 2025
585550b
cleanup
kylesayrs Aug 20, 2025
f49524a
add narrow match
kylesayrs Aug 20, 2025
7374312
better quant matching
kylesayrs Aug 20, 2025
dbc90ba
Merge remote-tracking branch 'origin' into attention-cache-submodules
kylesayrs Aug 20, 2025
8941779
attention and kv quantization
kylesayrs Aug 21, 2025
c4af508
remove debug prints
kylesayrs Aug 21, 2025
71463c7
add todo for other strategies (block/group, channel, head)
kylesayrs Aug 21, 2025
2cfff73
support registering to offloaded attention
kylesayrs Aug 25, 2025
53aa503
better merging and serialization logic
kylesayrs Aug 25, 2025
b47cda0
comment
kylesayrs Aug 25, 2025
c019e64
simplify hook replacement code
kylesayrs Aug 26, 2025
ceaf677
fix typo
kylesayrs Aug 26, 2025
29de3ec
do not attach scheme if not targeted
kylesayrs Aug 26, 2025
e904883
revert format changes
kylesayrs Aug 26, 2025
6f91dd6
fix typo
kylesayrs Aug 26, 2025
99d7143
Merge remote-tracking branch 'origin' into attention-cache-submodules
kylesayrs Aug 26, 2025
154d2e4
deprecate safe permute
kylesayrs Sep 11, 2025
ed8f5dc
meta hadamards
kylesayrs Sep 8, 2025
dfdbd3f
fix dynamic weights keys
kylesayrs Sep 8, 2025
7aec12c
break out _tie_offloaded_tensors, add test
kylesayrs Sep 8, 2025
2bf33c8
better comments
kylesayrs Sep 8, 2025
33b71b3
better comments
kylesayrs Sep 8, 2025
2ef1ab2
simplify function
kylesayrs Sep 11, 2025
a11770a
style
kylesayrs Sep 11, 2025
5438a81
better type hints, warn once
kylesayrs Sep 9, 2025
5875644
remove unneeded import
kylesayrs Sep 9, 2025
f179a91
allow-group-dynamic-quantization
kylesayrs Sep 8, 2025
f3d0e58
satisfy quality checker
kylesayrs Sep 8, 2025
8d35794
more clear
kylesayrs Sep 8, 2025
82ee671
basic support
kylesayrs Sep 9, 2025
c123637
fix merge
kylesayrs Sep 9, 2025
1016a75
ungate group activation quant
kylesayrs Sep 9, 2025
1c217e4
refactor
kylesayrs Sep 9, 2025
e5447f3
fix merge
kylesayrs Sep 11, 2025
6672617
reduce diff
kylesayrs Sep 11, 2025
199f274
activations have one row
kylesayrs Sep 12, 2025
d53ba36
cleanup, logging
kylesayrs Sep 17, 2025
f5390bd
Merge remote-tracking branch 'origin' into attention-cache-submodules
kylesayrs Sep 17, 2025
ab85d09
Merge branch 'kylesayrs/group-activation-quantization' into attention…
kylesayrs Sep 17, 2025
0d860cd
remove scheme merge
kylesayrs Sep 17, 2025
20744eb
use attention head quant
kylesayrs Sep 18, 2025
158 changes: 158 additions & 0 deletions src/compressed_tensors/modeling/attention.py
@@ -0,0 +1,158 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional

import torch
from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    forward_quantize,
)
from compressed_tensors.quantization.lifecycle.initialize import (
    _initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from torch.utils.hooks import RemovableHandle
from transformers import AttentionInterface, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


__all__ = ["IMPL_ATTR", "QuantizedAttentionImpl"]


IMPL_ATTR = "impl"
_original_impl = "eager" # mutable


class QuantizedAttentionImpl(InternalModule):
    def __init__(self, attn_module: torch.nn.Module):
        super().__init__()
        self.attn_module_container = [attn_module]  # avoid circular reference
        self._qparams_initialized = False

    def forward(
        self,
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        *args,
        **kwargs,
    ):
        # quantization
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled and self._qparams_initialized:
            query = forward_quantize(module, query, "q", quant_args)

        # original attention
        return ALL_ATTENTION_FUNCTIONS[_original_impl](
            module,
            query,
            key,
            value,
            *args,
            **kwargs,
        )

    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
        assert module is self.attn_module_container[0]
        scheme: Optional[QuantizationScheme] = getattr(
            module, "quantization_scheme", None
        )
        quant_args: Optional[QuantizationArgs] = getattr(
            scheme, "input_activations", None
        )

        if (
            not self._qparams_initialized
            and quant_args is not None
            and not scheme.kv_cache_only
        ):
            # TODO: use model.config.num_attention_heads to find query_size
            assert quant_args.strategy in (
                QuantizationStrategy.TENSOR,
                QuantizationStrategy.TOKEN,
                QuantizationStrategy.ATTN_HEAD,
            )

            num_heads = model.config.num_attention_heads
            hidden_size = model.config.hidden_size
            observed_dtype = next(module.parameters()).dtype
            _initialize_scale_zero_point(
                module,
                "q",
                quant_args,
                observed_shape=(num_heads, hidden_size),
                observed_dtype=observed_dtype,
                force_zero_point=True,
            )
            self._qparams_initialized = True


# ----- initialize ----- #


def ct_hooked_attention(module: torch.nn.Module, *args, **kwargs):
    if hasattr(module, IMPL_ATTR):
        return module.impl(module, *args, **kwargs)
    else:
        return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs)


def initialize_hooked_attention(
    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = True
):
    if not hasattr(module, IMPL_ATTR):
        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module))
        if model.config._attn_implementation != "ct_hooked_attention":
            # assumes only one model at a time
            global _original_impl
            _original_impl = model.config._attn_implementation

            AttentionInterface.register("ct_hooked_attention", ct_hooked_attention)
            model.config._attn_implementation = "ct_hooked_attention"

    impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
    if quantize:
        impl.initialize_qparams_once(model, module)

    initialize_hooked_kv_cache(model, module, quantize=quantize)


# ----- hooks ----- #


def register_query_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
    """
    Registers a forward pre-hook on `module.impl` that replaces the `query` argument
    with `hook(mod, query)` (handles both positional and keyword forms).
    """
    impl = getattr(module, IMPL_ATTR)

    def _hook(impl: QuantizedAttentionImpl, args, kwargs):
        # bind against the impl's forward signature (module, query, key, value, ...)
        bound = inspect.signature(impl.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["query"])
        if value is not None:
            bound.arguments["query"] = value

        return bound.args, bound.kwargs

    return impl.register_forward_pre_hook(_hook, with_kwargs=True)
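
Not part of the diff: a minimal usage sketch of the API added above, assuming a Llama-style module layout (`model.model.layers[*].self_attn`); the checkpoint name and the query transform are placeholders, not anything this PR prescribes.

```python
import torch
from transformers import AutoModelForCausalLM
from compressed_tensors.modeling.attention import (
    initialize_hooked_attention,
    register_query_hook,
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder model

def scale_query(module: torch.nn.Module, query: torch.Tensor) -> torch.Tensor:
    # placeholder transform; returning None would leave the query unchanged
    return query * 1.0

for layer in model.model.layers:
    attn = layer.self_attn
    initialize_hooked_attention(model, attn, quantize=False)  # also hooks the kv cache
    register_query_hook(attn, scale_query)
```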
163 changes: 163 additions & 0 deletions src/compressed_tensors/modeling/kvcache.py
@@ -0,0 +1,163 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Optional, Tuple

import torch
import transformers
from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
from compressed_tensors.quantization.lifecycle.initialize import (
    _initialize_scale_zero_point,
)
from compressed_tensors.utils import getattr_chain
from compressed_tensors.utils.internal import InternalModule
from packaging import version
from torch import Tensor
from torch.utils.hooks import RemovableHandle
from transformers import Cache, PreTrainedModel


__all__ = ["KV_CACHE_ATTR", "QuantizedKVCache"]


KV_CACHE_ATTR = "kv_cache"


class QuantizedKVCache(InternalModule):
    def __init__(self, attn_module: torch.nn.Module):
        super().__init__()
        self.attn_module_container = [attn_module]  # avoid nn.Module circular reference
        self.past_key_values: Optional[Cache] = None
        self._qparams_initialized = False

    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        return self(*args, **kwargs)

    def forward(
        self,
        key_states: Tensor,
        value_states: Tensor,
        *args,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        # quantization
        module = self.attn_module_container[0]
        quant_args_attr = "quantization_scheme.input_activations"
        quant_args = getattr_chain(module, quant_args_attr, None)
        quant_enabled = getattr(module, "quantization_enabled", True)
        if quant_args is not None and quant_enabled and self._qparams_initialized:
            key_states = forward_quantize(module, key_states, "k", quant_args)
            value_states = forward_quantize(module, value_states, "v", quant_args)

        # original cache
        if self.past_key_values is not None:
            ret = self.past_key_values.update(key_states, value_states, *args, **kwargs)
        else:
            ret = (key_states, value_states)

        self.past_key_values = None
        return ret

    def initialize_qparams_once(self, model: PreTrainedModel, module: torch.nn.Module):
        assert module is self.attn_module_container[0]
        scheme = getattr(module, "quantization_scheme", None)
        quant_args = getattr(scheme, "input_activations", None)

        if not self._qparams_initialized and quant_args is not None:
            # TODO: use model.config.num_key_value_heads to find key_size, value_size
            assert quant_args.strategy in (
                QuantizationStrategy.TENSOR,
                QuantizationStrategy.TOKEN,
                QuantizationStrategy.ATTN_HEAD,
            )
            num_heads = model.config.num_key_value_heads
            hidden_size = model.config.hidden_size
            observed_dtype = next(module.parameters()).dtype
            _initialize_scale_zero_point(
                module,
                "k",
                quant_args,
                observed_shape=(num_heads, hidden_size),
                observed_dtype=observed_dtype,
                force_zero_point=True,
            )
            _initialize_scale_zero_point(
                module,
                "v",
                quant_args,
                observed_shape=(num_heads, hidden_size),
                observed_dtype=observed_dtype,
                force_zero_point=True,
            )
            self._qparams_initialized = True


# ----- initialize ----- #


def initialize_hooked_kv_cache(
    model: PreTrainedModel, module: torch.nn.Module, quantize: bool = False
):
    if not hasattr(module, KV_CACHE_ATTR):
        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
        module.register_forward_pre_hook(kv_cache_attention_hook, with_kwargs=True)

    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    if quantize:
        kv_cache.initialize_qparams_once(model, module)


def kv_cache_attention_hook(module: torch.nn.Module, args, kwargs):
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
    _past_kv_name = (
        "past_key_value"
        if version.parse(transformers.__version__) <= version.parse("4.55.4")
        else "past_key_values"  # transformers#39956
    )
    kv_cache.past_key_values = kwargs.get(_past_kv_name, None)
    kwargs[_past_kv_name] = kv_cache

    return args, kwargs


# ----- hooks ----- #


def register_key_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["key_states"])
        if value is not None:
            bound.arguments["key_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)


def register_value_hook(module: torch.nn.Module, hook: Callable) -> RemovableHandle:
    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)

    def _hook(cache: QuantizedKVCache, args, kwargs):
        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
        value = hook(module, bound.arguments["value_states"])
        if value is not None:
            bound.arguments["value_states"] = value

        return bound.args, bound.kwargs

    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
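
Again not part of the diff: a sketch of how the key/value hooks might be used for calibration, reusing the hypothetical `model` from the sketch above; the observer logic and the `_k_absmax`/`_v_absmax` attributes are placeholders, not part of this PR.

```python
import torch
from compressed_tensors.modeling.kvcache import (
    initialize_hooked_kv_cache,
    register_key_hook,
    register_value_hook,
)

def observe_keys(module: torch.nn.Module, key_states: torch.Tensor) -> None:
    # placeholder calibration logic; returning None keeps key_states unchanged
    module._k_absmax = key_states.abs().amax()

def observe_values(module: torch.nn.Module, value_states: torch.Tensor) -> None:
    module._v_absmax = value_states.abs().amax()

for layer in model.model.layers:  # `model` as in the previous sketch
    attn = layer.self_attn
    initialize_hooked_kv_cache(model, attn, quantize=False)
    register_key_hook(attn, observe_keys)
    register_value_hook(attn, observe_values)
```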