
Commit 0360ebc

Support transformers v5
Signed-off-by: SimJeg <[email protected]>
1 parent 68f8ef8 commit 0360ebc

File tree: 8 files changed (+30, -31 lines)

README.md

Lines changed: 2 additions & 3 deletions
@@ -170,10 +170,9 @@ Below we report the average performance on the RULER dataset with 4k context len
 We support KV cache quantization through the transformers `QuantizedCache` class (see [HF blog post](https://huggingface.co/blog/kv-cache-quantization#how-to-use-quantized-kv-cache-in-%F0%9F%A4%97-transformers)). To use it, simply pass a cache object to your pipeline:
 
 ```python
-from transformers import QuantizedCacheConfig, QuantoQuantizedCache
+from transformers import QuantizedCache
 
-config = QuantizedCacheConfig(nbits=4)
-cache = QuantoQuantizedCache(config)
+cache = QuantizedCache(backend="quanto", nbits=4)
 
 pipe(..., cache=cache)
 ```
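
For reference, a minimal end-to-end sketch of the updated README usage. The model checkpoint is illustrative and the pipeline task name is assumed to be the one registered by kvpress; only the `QuantizedCache(backend="quanto", nbits=4)` construction and the `pipe(..., press=..., cache=...)` call come from this commit.

```python
# Sketch only: model name is illustrative; the pipeline task name is assumed
# to be the one registered by kvpress on import.
from transformers import QuantizedCache, pipeline

from kvpress import ExpectedAttentionPress

pipe = pipeline("kv-press-text-generation", model="meta-llama/Llama-3.2-1B-Instruct")

press = ExpectedAttentionPress(compression_ratio=0.4)
# transformers v5: a single QuantizedCache class selects its backend via an argument,
# replacing the former QuantizedCacheConfig + QuantoQuantizedCache pair.
cache = QuantizedCache(backend="quanto", nbits=4)

context = "This is a test article. It was written on 2022-01-01."
questions = ["When was this article written?"]
answers = pipe(context, questions=questions, press=press, cache=cache)["answers"]
print(answers[0])
```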

kvpress/presses/kvzip_press.py

Lines changed: 5 additions & 6 deletions
@@ -363,14 +363,13 @@ def compress_post(self, model: PreTrainedModel):
 
         # calculate the pruned KV pairs across layers
         if self.layerwise:
-            nl = int(num_key_value_heads * ctx_len * self.compression_ratio)
+            nl = int(bsz * num_key_value_heads * ctx_len * self.compression_ratio)
             n_pruned_layers = nl * torch.ones(n_layer, device=self.score_val.device, dtype=torch.int)
         else:
-            score_sort = torch.sort(self.score_val.reshape(-1)).values  # ascending order
-            n = max(int(len(score_sort) * self.compression_ratio) - 1, 0)
-            thres = score_sort[n].item()
-
-            n_pruned_layers = (self.score_val.reshape(n_layer, -1) <= thres).sum(-1)  # n_prune
+            n_pruned_indices = int(self.score_val.numel() * self.compression_ratio)
+            pruned_indices = torch.topk(-self.score_val.reshape(-1), n_pruned_indices).indices
+            n_tokens_per_layer = bsz * num_key_value_heads * ctx_len
+            n_pruned_layers = torch.bincount(pruned_indices // n_tokens_per_layer, minlength=n_layer).int()
 
         for layer in model.model.layers:
             module = layer.self_attn
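
The non-layerwise branch now derives per-layer budgets by ranking all scores globally and counting where the lowest ones land, instead of thresholding against a sorted copy. A self-contained toy illustration of that counting step; the shapes and the layer-major layout of `score_val` are assumptions made for the sketch.

```python
import torch

n_layer, bsz, num_key_value_heads, ctx_len = 4, 1, 2, 8
compression_ratio = 0.25

# assumed layout: one score per (layer, batch, kv head, token)
score_val = torch.rand(n_layer, bsz, num_key_value_heads, ctx_len)

# number of KV entries to prune across all layers
n_pruned_indices = int(score_val.numel() * compression_ratio)
# topk of the negated scores returns the indices of the globally lowest scores
pruned_indices = torch.topk(-score_val.reshape(-1), n_pruned_indices).indices

n_tokens_per_layer = bsz * num_key_value_heads * ctx_len
# integer division maps a flat index back to its layer; bincount tallies per layer
n_pruned_layers = torch.bincount(pruned_indices // n_tokens_per_layer, minlength=n_layer).int()

assert n_pruned_layers.sum().item() == n_pruned_indices  # 16 of 64 entries pruned
print(n_pruned_layers)  # per-layer number of KV pairs to prune
```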

notebooks/speed_and_memory.ipynb

Lines changed: 8 additions & 8 deletions
@@ -19,7 +19,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 2,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -35,7 +35,7 @@
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, pipeline\n",
-   "from transformers import QuantizedCacheConfig, QuantoQuantizedCache, DynamicCache, QuantizedCache\n",
+   "from transformers import DynamicCache, QuantizedCache\n",
    "from transformers.utils.logging import disable_progress_bar\n",
    "import transformers\n",
    "\n",
@@ -65,19 +65,19 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_size_of_cache(cache):\n",
-   "    if isinstance(cache, QuantoQuantizedCache):\n",
+   "    if isinstance(cache, QuantizedCache):\n",
    "        # We cannot use x.element_size() * x.nelement() as below to calculate the size of the cache, \n",
    "        # as cache._quantized_value_cache[0].element_size() triggers a call of __torch_dispatch__,\n",
    "        # which, in turn, unpacks the internally packed tensor; and thus does not report the correct internal storage size.\n",
    "        # See also https://github.com/huggingface/optimum-quanto/blob/main/optimum/quanto/tensor/packed.py#L144\n",
    "\n",
-   "        # As QuantoQuantizedCache stores values, as well as shift and scale, \n",
-   "        # we temporarily save the cache to disc and getthe size of the saved object\n",
+   "        # As QuantizedCache stores values, as well as shift and scale, \n",
+   "        # we temporarily save the cache to disc and get the size of the saved object\n",
    "        temp_file = \"tmp.pickle\"\n",
    "        with open(temp_file, \"wb\") as f:\n",
    "            pickle.dump(cache, f)\n",
@@ -100,7 +100,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -125,7 +125,7 @@
    "    if cache_implementation == \"dynamic\":\n",
    "        cache = DynamicCache()\n",
    "    elif cache_implementation == \"quantized\":\n",
-   "        cache = QuantoQuantizedCache(config=model.config, nbits=4)\n",
+   "        cache = QuantizedCache(backend=\"quanto\", config=model.config, nbits=4)\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Cache {cache_implementation} not yet implemented\")\n",
    "\n",

pyproject.toml

Lines changed: 1 addition & 3 deletions
@@ -13,9 +13,7 @@ readme = "README.md"
 dependencies = [
     "numpy>=2.0.0,<3",
     "torch>=2.3.1,<3",
-    # transformers<4.54 is not supported due to refactoring of the transformers library.
-    # transformers 4.54-4.55.2 are not compatible with kvpress due to flash attention bugs in transformers
-    "transformers>=4.56,<5.0.0",
+    "transformers>=5.0.0",
     "sentencepiece>=0.2.0,<0.3",
     "protobuf>=5.27.2,<6",
     "datasets>=2.21.0,<3",

tests/integration/test_ruler.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 import datasets
 import pytest
 import torch
-from transformers import DynamicCache, QuantoQuantizedCache
+from transformers import DynamicCache, QuantizedCache
 from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
 
 from kvpress import QFilterPress
@@ -44,7 +44,7 @@ def test_ruler_is_correct(
     if cache == "dynamic":
         cache = DynamicCache()
     elif cache == "quantized" and is_optimum_quanto_available():
-        cache = QuantoQuantizedCache(config=kv_press_qwen3_flash_attn_pipeline.model.config, nbits=4)
+        cache = QuantizedCache(backend="quanto", config=kv_press_qwen3_flash_attn_pipeline.model.config, nbits=4)
     elif cache == "quantized" and not is_optimum_quanto_available():
         pytest.skip("Quanto is not installed")
     else:
@@ -89,7 +89,7 @@ def test_ruler_is_correct_for_qfilter(
     if cache == "dynamic":
         cache = DynamicCache()
     elif cache == "quantized" and is_optimum_quanto_available():
-        cache = QuantoQuantizedCache(config=kv_press_llama3_2_flash_attn_pipeline.model.config, nbits=4)
+        cache = QuantizedCache(backend="quanto", config=kv_press_llama3_2_flash_attn_pipeline.model.config, nbits=4)
     elif cache == "quantized" and not is_optimum_quanto_available():
         pytest.skip("Quanto is not installed")
     else:
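
The same cache-selection pattern, factored into a small helper for clarity. The function name and error handling are illustrative; only the `QuantizedCache(backend="quanto", config=..., nbits=4)` call and the `is_optimum_quanto_available` guard come from the diff.

```python
from transformers import DynamicCache, QuantizedCache
from transformers.utils import is_optimum_quanto_available


def build_cache(kind: str, model_config):
    """Illustrative helper mirroring the cache selection used in these tests."""
    if kind == "dynamic":
        return DynamicCache()
    if kind == "quantized":
        if not is_optimum_quanto_available():
            raise RuntimeError("optimum-quanto is required for backend='quanto'")
        return QuantizedCache(backend="quanto", config=model_config, nbits=4)
    raise ValueError(f"Unknown cache kind: {kind}")
```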

tests/presses/test_key_rerotation_press_rope.py

Lines changed: 3 additions & 0 deletions
@@ -48,6 +48,7 @@ def test_rerotate_keys_is_matches_reference_implementation(
         "rope_type": "yarn",
     }
     cfg.max_position_embeddings = 131072
+    cfg.rope_theta = 500000.0
     try:
         unit_test_model.model.rotary_emb = LlamaRotaryEmbedding(cfg, device=unit_test_model.device)
     except KeyError:
@@ -63,6 +64,8 @@
         unit_test_model = unit_test_model.cuda().half()
     elif precision == "half":
         pytest.skip("Half-precision test skipped because CUDA is not available.")
+    elif precision == "full":
+        unit_test_model = unit_test_model.float()
 
     original_press = RandomPressStoreIndices(compression_ratio=0.5)
     key_rerotation_press = KeyRerotationPress(press=original_press)
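
For context, a hedged sketch of building a yarn-scaled `LlamaRotaryEmbedding` from a hand-made config, as this test does. The `factor` and `original_max_position_embeddings` values are illustrative and not taken from the test; only `rope_type="yarn"`, `max_position_embeddings=131072`, `rope_theta=500000.0`, and the `LlamaRotaryEmbedding(cfg, device=...)` call come from the diff.

```python
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

cfg = LlamaConfig()  # small default config, standing in for the test fixture's config
cfg.rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,                               # assumed value
    "original_max_position_embeddings": 32768,   # assumed value
}
cfg.max_position_embeddings = 131072
cfg.rope_theta = 500000.0  # set explicitly, as the test now does

rotary_emb = LlamaRotaryEmbedding(cfg, device="cpu")
```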

tests/test_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 import pytest
 import torch
-from transformers import AutoTokenizer, DynamicCache, QuantoQuantizedCache
+from transformers import AutoTokenizer, DynamicCache, QuantizedCache
 from transformers.utils import is_flash_attn_2_available, is_optimum_quanto_available
 
 from kvpress import ExpectedAttentionPress
@@ -112,7 +112,7 @@ def test_pipeline_with_quantized_cache(kv_press_danube_pipeline, caplog): # noq
     context = "This is a test article. It was written on 2022-01-01."
     questions = ["When was this article written?"]
     press = ExpectedAttentionPress(compression_ratio=0.4)
-    cache = QuantoQuantizedCache(config=kv_press_danube_pipeline.model.config, nbits=4)
+    cache = QuantizedCache(backend="quanto", config=kv_press_danube_pipeline.model.config, nbits=4)
     answers = kv_press_danube_pipeline(context, questions=questions, press=press, cache=cache)["answers"]
 
     assert len(answers) == 1

tests/test_press_call.py

Lines changed: 6 additions & 6 deletions
@@ -28,13 +28,13 @@ def test_context_manager_applies_compression(unit_test_model): # noqa: F811
 
     seq_len = input_ids.shape[-1]
 
-    for key, values in past_key_values:
-        assert key.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
-        assert values.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
+    for layer in past_key_values.layers:
+        assert layer.keys.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
+        assert layer.values.shape[2] == int(seq_len * 0.8) == past_key_values.get_seq_length()
 
     input_ids = unit_test_model.dummy_inputs["input_ids"].to(unit_test_model.device)
     past_key_values = unit_test_model(input_ids, past_key_values=DynamicCache()).past_key_values
 
-    for key, values in past_key_values:
-        assert key.shape[2] == seq_len == past_key_values.get_seq_length()
-        assert values.shape[2] == seq_len == past_key_values.get_seq_length()
+    for layer in past_key_values.layers:
+        assert layer.keys.shape[2] == seq_len == past_key_values.get_seq_length()
+        assert layer.values.shape[2] == seq_len == past_key_values.get_seq_length()
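
A minimal standalone sketch of the v5 cache-inspection API exercised above: layers are reached through `cache.layers` with `.keys`/`.values` attributes instead of unpacking `(key, value)` tuples. The dummy shapes are arbitrary, and the sketch assumes `DynamicCache.update` lazily creates the layer it is given.

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
# populate one layer with dummy tensors of shape (batch, kv_heads, seq_len, head_dim)
keys = torch.zeros(1, 2, 5, 8)
values = torch.zeros(1, 2, 5, 8)
cache.update(keys, values, layer_idx=0)

# transformers v5: iterate layers and read .keys / .values instead of tuple unpacking
for layer in cache.layers:
    assert layer.keys.shape[2] == cache.get_seq_length() == 5
    assert layer.values.shape[2] == cache.get_seq_length() == 5
```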
