
Commit cbce8dc

yl: notebook->notebooks
1 parent ac2857b commit cbce8dc

3 files changed: +259 −0 lines changed


notebooks/test_snapkv.ipynb

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/tianle/miniconda3/envs/code_attn/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# CUDA_VISIBLE_DEVICES\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"5,6\"\n",
    "\n",
    "os.environ['HF_DATASETS_CACHE'] = \"/work/tianle/huggingface/datasets\"\n",
    "os.environ['HF_HOME'] = \"/work/tianle/huggingface\"\n",
    "\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig\n",
    "import transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from snapkv.monkeypatch.monkeypatch import replace_llama, replace_mistral, replace_mixtral"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/work/data/tianle/share/SnapKV/snapkv/monkeypatch/monkeypatch.py:50: UserWarning: Transformers version 4.36.2 might not be compatible with SnapKV. SnapKV is tested with Transformers version ['4.37'].\n",
      " warnings.warn(f\"Transformers version {transformers_version} might not be compatible with SnapKV. SnapKV is tested with Transformers version {version_list}.\")\n"
     ]
    }
   ],
   "source": [
    "replace_mixtral()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastchat.model import load_model, get_conversation_template\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transformers version: 4.36.2\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "try:\n",
    "    transformers_version = version(\"transformers\")\n",
    "    print(f\"Transformers version: {transformers_version}\")\n",
    "except Exception as e:\n",
    "    print(f\"Error: {e}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation=\"flash_attention_2\"` instead.\n",
      "Loading checkpoint shards: 100%|██████████| 19/19 [01:10<00:00, 3.71s/it]\n"
     ]
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    device_map=\"auto\",\n",
    "    use_flash_attention_2=True\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mixtral-8x7B-Instruct-v0.1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load './snapkv.txt'\n",
    "with open('snapkv.txt', 'r') as f:\n",
    "    content = f.read().strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"\\n What is the repository of SnapKV?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "conv = get_conversation_template(\"longchat\")\n",
    "conv.messages = []\n",
    "conv.append_message(conv.roles[0], content + question)\n",
    "# conv.append_message(conv.roles[0], \"Who is Kobe Bryant?\")\n",
    "conv.append_message(conv.roles[1], None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = conv.get_prompt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = tokenizer.encode(prompt, return_tensors='pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids_len = input_ids.size(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    }
   ],
   "source": [
    "outputs = model.generate(input_ids.cuda(), max_new_tokens=200, do_sample=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The repository of SnapKV is available at <https://github.com/FasterDecoding/SnapKV>.\n"
     ]
    }
   ],
   "source": [
    "print(tokenizer.decode(outputs[0][input_ids_len:], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "code_attn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
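
Note: the stderr output of the model-loading cell warns that use_flash_attention_2=True is deprecated in favor of attn_implementation="flash_attention_2". A minimal sketch of the equivalent load call using the argument that warning recommends (not part of the committed notebook; the model ID, dtype, and device map are taken from the cell above):

import torch
from transformers import AutoModelForCausalLM

# Same load as in the notebook, but using the non-deprecated argument
# named in the deprecation warning instead of use_flash_attention_2=True.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
    attn_implementation="flash_attention_2",
)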

0 commit comments
