Skip to content

Commit 395b54d

Browse files
committed
yl: refactor
1 parent ea655b1 commit 395b54d

19 files changed

+5825
-24
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,5 @@ cython_debug/
157157
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158158
# and can be added to the global gitignore or merged into this file. For a more nuclear
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160-
#.idea/
160+
#.idea/
161+
notebook/test*

README.md

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,51 @@
11
# SnapKV :camera:
2-
We introduce an innovative and out-of-box KV cache compression method, SnapKV.
2+
We introduce an innovative and out-of-box KV cache compression method, [SnapKV](https://arxiv.org/abs/2404.14469).
33
## Requirements
4-
`transformers>=4.36`
4+
Currently tested with `transformers==4.37.0`, need to check if it is compatible with higher version.
5+
```
6+
transformers>=4.36
7+
flash-attn==2.4.0
8+
```
9+
## Installation
10+
```
11+
git clone [email protected]:FasterDecoding/SnapKV.git
12+
cd SnapKV
13+
pip install -e .
14+
```
515
## Quick Start
616
### Use SnapKV-optimized Models
7-
SnapKV-optimized models are all under models file, which could be directly imported and used the same like baseline models.
817
For example:
918
```python
10-
from models.modeling_mistral import MistralForCausalLM as SnapKVMistralForCausalLM
11-
model = SnapKVMistralForCausalLM.from_pretrained(
12-
model_name,
13-
torch_dtype=torch.float16,
14-
low_cpu_mem_usage=True,
15-
device_map="auto",
16-
use_flash_attention_2=True
17-
)
18-
tokenizer = transformers.AutoTokenizer.from_pretrained(
19-
model_name,
20-
padding_side="right",
21-
use_fast=False,
22-
)
19+
from snapkv.monkeypatch.monkeypatch import replace_mistral
20+
replace_mistral() # Use monkey patches enable SnapKV
2321
```
2422

23+
Check [the example notebook](./notebook/example.ipynb).
24+
2525
### Customize Your SnapKV-optimized Models
26-
SnapKV can be easily integrate with other models. You can follow the comment marked with `[SnapKV]` in [existing models](./models) to constrcut your own models. The detailed algorithm of SnapKV is in [snapkv_utils.py](./snapkv_utils.py)
26+
SnapKV can be easily integrate with other models.
2727

28+
You can follow the comment marked with `[SnapKV]` in [existing models](./snapkv/monkeypatch/monkeypatch.py) to construct your own models. (Currently we support [Llama family](./snapkv/monkeypatch/llama_hijack_4_37.py)/ [Mistral](./snapkv/monkeypatch//mistral_hijack_4_37.py)/ [Mixtral](./snapkv/monkeypatch//mixtral_hijack_4_37.py))
2829

29-
## Results
30-
![Comprehensive Experiment Results on LongBench](./figures/longbench.jpg)
31-
![Pressure Test Result on Needle-in-a-Haystack](./figures/LWM-Text-Chat-1M_SnapKV.jpg)
30+
The detailed algorithm of SnapKV is in [`snapkv_utils.py`](./snapkv/monkeypatch/snapkv_utils.py)
31+
32+
33+
## Partial Results
34+
![Comprehensive Experiment Results on LongBench](./assets/longbench.jpg)
35+
![Pressure Test Result on Needle-in-a-Haystack](./assets/LWM-Text-Chat-1M_SnapKV.jpg)
36+
37+
## TODO
38+
- [ ] Add observation experiments for reduplication.
39+
- [ ] Add LongBench for reduplication.
40+
- [ ] Explore the prompt phase compression.
41+
42+
## Citation
43+
If you feel this project is helpful, please consider cite our report :blush:
44+
```
45+
@article{li2024snapkv,
46+
title={SnapKV: LLM Knows What You are Looking for Before Generation},
47+
author={Li, Yuhong and Huang, Yingbing and Yang, Bowen and Venkitesh, Bharat and Locatelli, Acyr and Ye, Hanchen and Cai, Tianle and Lewis, Patrick and Chen, Deming},
48+
journal={arXiv preprint arXiv:2404.14469},
49+
year={2024}
50+
}
51+
```

notebook/example.ipynb

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import os\n",
10+
"# CUDAVISIBLE DEVICES\n",
11+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
12+
"import torch\n",
13+
"from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig\n",
14+
"import transformers"
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": null,
20+
"metadata": {},
21+
"outputs": [],
22+
"source": [
23+
"from snapkv.monkeypatch.monkeypatch import replace_llama, replace_mistral, replace_mixtral"
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": null,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"replace_mistral()"
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": null,
38+
"metadata": {},
39+
"outputs": [],
40+
"source": [
41+
"from fastchat.model import load_model, get_conversation_template"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": null,
47+
"metadata": {},
48+
"outputs": [],
49+
"source": [
50+
"model = AutoModelForCausalLM.from_pretrained(\n",
51+
" \"mistralai/Mistral-7B-Instruct-v0.2\",\n",
52+
" torch_dtype=torch.bfloat16,\n",
53+
" low_cpu_mem_usage=True,\n",
54+
" device_map=\"auto\",\n",
55+
" use_flash_attention_2=True\n",
56+
" )"
57+
]
58+
},
59+
{
60+
"cell_type": "code",
61+
"execution_count": null,
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-Instruct-v0.2\")"
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": null,
71+
"metadata": {},
72+
"outputs": [],
73+
"source": [
74+
"with open('snapkv.txt', 'r') as f:\n",
75+
" content = f.read().strip()"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"metadata": {},
82+
"outputs": [],
83+
"source": [
84+
"question = \"\\n What is the repository of SnapKV?\""
85+
]
86+
},
87+
{
88+
"cell_type": "code",
89+
"execution_count": null,
90+
"metadata": {},
91+
"outputs": [],
92+
"source": [
93+
"conv = get_conversation_template(\"longchat\")\n",
94+
"conv.messages = []\n",
95+
"conv.append_message(conv.roles[0],content + question)\n",
96+
"# conv.append_message(conv.roles[0],\"Who is Kobe Bryant?\")\n",
97+
"conv.append_message(conv.roles[1], None)"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": null,
103+
"metadata": {},
104+
"outputs": [],
105+
"source": [
106+
"prompt = conv.get_prompt()"
107+
]
108+
},
109+
{
110+
"cell_type": "code",
111+
"execution_count": null,
112+
"metadata": {},
113+
"outputs": [],
114+
"source": [
115+
"input_ids = tokenizer.encode(prompt, return_tensors='pt')"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"metadata": {},
122+
"outputs": [],
123+
"source": [
124+
"input_ids_len = input_ids.size(1)\n",
125+
"print(input_ids_len)"
126+
]
127+
},
128+
{
129+
"cell_type": "code",
130+
"execution_count": null,
131+
"metadata": {},
132+
"outputs": [],
133+
"source": [
134+
"outputs = model.generate(input_ids.cuda(), max_new_tokens=200, do_sample=False)"
135+
]
136+
},
137+
{
138+
"cell_type": "code",
139+
"execution_count": null,
140+
"metadata": {},
141+
"outputs": [],
142+
"source": [
143+
"print(tokenizer.decode(outputs[0][input_ids_len:], skip_special_tokens=True))"
144+
]
145+
}
146+
],
147+
"metadata": {
148+
"kernelspec": {
149+
"display_name": "code_attn",
150+
"language": "python",
151+
"name": "python3"
152+
},
153+
"language_info": {
154+
"codemirror_mode": {
155+
"name": "ipython",
156+
"version": 3
157+
},
158+
"file_extension": ".py",
159+
"mimetype": "text/x-python",
160+
"name": "python",
161+
"nbconvert_exporter": "python",
162+
"pygments_lexer": "ipython3",
163+
"version": "3.11.0"
164+
}
165+
},
166+
"nbformat": 4,
167+
"nbformat_minor": 2
168+
}

0 commit comments

Comments
 (0)