
Commit 10ace39

Merge pull request #13 from bigcode-project/dev
Dev
2 parents 32cdafd + 11720f3 commit 10ace39

2 files changed: +579 −0 lines changed


c2c_search_eval.ipynb

Lines changed: 273 additions & 0 deletions
@@ -0,0 +1,273 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "import torch\n",
    "from src import datasets_loader\n",
    "from src.utils import retrieval_eval, pool_and_normalize\n",
    "from src.constants import GFG_DATA_PATH\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "from src.datasets_loader import prepare_tokenizer\n",
    "from src.preprocessing_utils import truncate_sentences\n",
    "from abc import ABC, abstractmethod\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = \"cuda:0\"\n",
    "\n",
    "EVAL_CONFIGS =[\n",
    "    {\"model_path\": \"starencoder\", \"maximum_raw_length\": 10000, \"maximum_input_length\": 1024, \"device\": DEVICE},\n",
    "    {\"model_path\": \"codebert\", \"maximum_raw_length\": 10000, \"maximum_input_length\": 512, \"device\": DEVICE}\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_device(inputs: Dict[str, torch.Tensor], device: str) -> Dict[str, torch.Tensor]:\n",
    "    output_data = {}\n",
    "    for k, v in inputs.items():\n",
    "        output_data[k] = v.to(device)\n",
    "\n",
    "    return output_data\n",
    "\n",
    "\n",
    "def get_dataset(maximum_raw_length):\n",
    "    test_data = datasets_loader.get_dataset( # Geeks4Geeks data\n",
    "        dataset_name=\"gfg\",\n",
    "        path_to_cache=GFG_DATA_PATH,\n",
    "        split=\"test\",\n",
    "        maximum_raw_length=maximum_raw_length,\n",
    "    )\n",
    "\n",
    "    return test_data\n",
    "\n",
    "\n",
    "class BaseEncoder(torch.nn.Module, ABC):\n",
    "    def __init__(self, device, max_input_len, maximum_token_len, model_name):\n",
    "        super().__init__()\n",
    "\n",
    "        self.model_name = model_name\n",
    "        self.tokenizer = prepare_tokenizer(model_name)\n",
    "        self.encoder = (\n",
    "            AutoModel.from_pretrained(model_name, use_auth_token=True).to(DEVICE).eval()\n",
    "        )\n",
    "        self.device = device\n",
    "        self.max_input_len = max_input_len\n",
    "        self.maximum_token_len = maximum_token_len\n",
    "\n",
    "    @abstractmethod\n",
    "    def forward(\n",
    "        self,\n",
    "    ):\n",
    "        pass\n",
    "\n",
    "    def encode(self, input_sentences, batch_size=32, **kwargs):\n",
    "        truncated_input_sentences = truncate_sentences(\n",
    "            input_sentences, self.max_input_len\n",
    "        )\n",
    "\n",
    "        n_batches = len(truncated_input_sentences) // batch_size + int(\n",
    "            len(truncated_input_sentences) % batch_size > 0\n",
    "        )\n",
    "\n",
    "        embedding_batch_list = []\n",
    "\n",
    "        for i in range(n_batches):\n",
    "            start_idx = i * batch_size\n",
    "            end_idx = min((i + 1) * batch_size, len(truncated_input_sentences))\n",
    "\n",
    "            with torch.no_grad():\n",
    "                embedding_batch_list.append(\n",
    "                    self.forward(truncated_input_sentences[start_idx:end_idx])\n",
    "                    .detach()\n",
    "                    .cpu()\n",
    "                )\n",
    "\n",
    "        return torch.cat(embedding_batch_list)\n",
    "\n",
    "\n",
    "class StarEncoder(BaseEncoder):\n",
    "    def __init__(self, device, max_input_len, maximum_token_len):\n",
    "        super().__init__(\n",
    "            device,\n",
    "            max_input_len,\n",
    "            maximum_token_len,\n",
    "            model_name=\"bigcode/starencoder\",\n",
    "        )\n",
    "\n",
    "    def forward(self, input_sentences):\n",
    "        inputs = self.tokenizer(\n",
    "            [\n",
    "                f\"{self.tokenizer.cls_token}{sentence}{self.tokenizer.sep_token}\"\n",
    "                for sentence in input_sentences\n",
    "            ],\n",
    "            padding=\"longest\",\n",
    "            max_length=self.maximum_token_len,\n",
    "            truncation=True,\n",
    "            return_tensors=\"pt\",\n",
    "        )\n",
    "\n",
    "        outputs = self.encoder(**set_device(inputs, self.device))\n",
    "        embedding = pool_and_normalize(outputs.hidden_states[-1], inputs.attention_mask)\n",
    "\n",
    "        return embedding\n",
    "\n",
    "\n",
    "class CodeBERT(BaseEncoder):\n",
    "    def __init__(self, device, max_input_len, maximum_token_len):\n",
    "        super().__init__(\n",
    "            device,\n",
    "            max_input_len,\n",
    "            maximum_token_len,\n",
    "            model_name=\"microsoft/codebert-base\",\n",
    "        )\n",
    "\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(\"microsoft/codebert-base\")\n",
    "\n",
    "    def forward(self, input_sentences):\n",
    "        inputs = self.tokenizer(\n",
    "            [sentence for sentence in input_sentences],\n",
    "            padding=\"longest\",\n",
    "            max_length=self.maximum_token_len,\n",
    "            truncation=True,\n",
    "            return_tensors=\"pt\",\n",
    "        )\n",
    "\n",
    "        inputs = set_device(inputs, self.device)\n",
    "\n",
    "        outputs = self.encoder(inputs[\"input_ids\"], inputs[\"attention_mask\"])\n",
    "\n",
    "        embedding = outputs[\"pooler_output\"]\n",
    "\n",
    "        return torch.cat([torch.Tensor(el)[None, :] for el in embedding])\n",
    "\n",
    "\n",
    "def evaluate(model_path, maximum_raw_length, maximum_input_length, device):\n",
    "    if \"starencoder\" in model_path.lower():\n",
    "        model = StarEncoder(\n",
    "            device, maximum_raw_length, maximum_input_length\n",
    "        )\n",
    "    elif \"codebert\" in model_path.lower():\n",
    "        model = CodeBERT(\n",
    "            device, maximum_raw_length, maximum_input_length\n",
    "        )\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            \"Unsupported model type. We currently support starencoder and codebert.\"\n",
    "        )\n",
    "\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    test_data = get_dataset(maximum_raw_length)\n",
    "\n",
    "    source_entries, target_entries = [], []\n",
    "    for source, target in test_data:\n",
    "        source_entries.append(source)\n",
    "        target_entries.append(target)\n",
    "\n",
    "    source_embeddings = model.encode(source_entries)\n",
    "    target_embeddings = model.encode(target_entries)\n",
    "\n",
    "    recall_at_1, recall_at_5, mean_reciprocal_rank = retrieval_eval(\n",
    "        source_embeddings, target_embeddings\n",
    "    )\n",
    "\n",
    "    print(\n",
    "        f\"\\n{model_path}: R@1: {recall_at_1.item()}, R@5: {recall_at_5.item()}, MRR: {mean_reciprocal_rank.item()}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using pad_token, but it is not set yet.\n",
      "Using sep_token, but it is not set yet.\n",
      "Using cls_token, but it is not set yet.\n",
      "Using mask_token, but it is not set yet.\n",
      "Some weights of the model checkpoint at bigcode/starencoder were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Loading cached shuffled indices for dataset at /mnt/home/research-BertBigCode/resources/data/transcoder_evaluation_gfg/cache-e9f62aa12abed28d.arrow\n",
      "Loading cached processed dataset at /mnt/home/research-BertBigCode/resources/data/transcoder_evaluation_gfg/cache-62c8dbaa90db85ee_*_of_00096.arrow\n",
      "Loading cached shuffled indices for dataset at /mnt/home/research-BertBigCode/resources/data/transcoder_evaluation_gfg/cache-f652c1e33d8c1a14.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "starencoder: R@1: 0.7222222089767456, R@5: 0.8767361044883728, MRR: 0.7930026054382324\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached shuffled indices for dataset at /mnt/home/research-BertBigCode/resources/data/transcoder_evaluation_gfg/cache-e9f62aa12abed28d.arrow\n",
      "Loading cached processed dataset at /mnt/home/research-BertBigCode/resources/data/transcoder_evaluation_gfg/cache-62c8dbaa90db85ee_*_of_00096.arrow\n",
      "Loading cached shuffled indices for dataset at /mnt/home/research-BertBigCode/resources/data/transcoder_evaluation_gfg/cache-f652c1e33d8c1a14.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "codebert: R@1: 0.0052083334885537624, R@5: 0.02777777798473835, MRR: 0.025095948949456215\n"
     ]
    }
   ],
   "source": [
    "for eval_cfg in EVAL_CONFIGS:\n",
    "    evaluate(**eval_cfg)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "ae635839a86c404533bb974203baf1bd26d9dc49bfbf145b45e9350c30045fdd"
  },
  "kernelspec": {
   "display_name": "Python 3.9.13 64-bit ('accelerate')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
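
Notes on helpers the notebook imports but that are not part of this diff:

The stderr output above ("Using pad_token, but it is not set yet." and similar) suggests that the bigcode/starencoder tokenizer ships without pad/sep/cls/mask tokens and that `prepare_tokenizer` from `src.datasets_loader` registers them before tokenization. A minimal, hypothetical sketch of such a helper; the function name and token strings here are assumptions, not the project's actual code:

```python
from transformers import AutoTokenizer


def prepare_tokenizer_sketch(model_name: str):
    """Hypothetical stand-in for src.datasets_loader.prepare_tokenizer (assumption)."""
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
    # Register the special tokens the stderr warnings report as missing.
    tokenizer.add_special_tokens(
        {"pad_token": "[PAD]", "sep_token": "[SEP]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
    )
    return tokenizer
```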
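`StarEncoder.forward` passes the last hidden states and the attention mask to `pool_and_normalize` from `src.utils`, which is also not shown in this diff. A minimal sketch, assuming it performs masked mean pooling followed by L2 normalization (the project's implementation may differ):

```python
import torch


def pool_and_normalize_sketch(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for src.utils.pool_and_normalize (assumption).

    hidden_states: (batch, seq_len, dim) last-layer token embeddings.
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding.
    """
    mask = attention_mask.unsqueeze(-1).type_as(hidden_states)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)                  # sum over non-padding tokens
    counts = mask.sum(dim=1).clamp(min=1e-9)                    # number of real tokens per row
    pooled = summed / counts                                    # masked mean pooling
    return torch.nn.functional.normalize(pooled, p=2, dim=1)    # unit-length embeddings
```

Normalizing to unit length would make a plain dot product between source and target embeddings equivalent to cosine similarity, which is convenient for the retrieval evaluation.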
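Finally, `evaluate` reports R@1, R@5, and MRR via `retrieval_eval` from `src.utils`. Assuming the i-th target snippet is the positive match for the i-th source snippet and that embeddings are compared by dot product, a hedged sketch of these metrics could look like this (`retrieval_eval_sketch` is a hypothetical name, not the project's function):

```python
import torch


def retrieval_eval_sketch(source_emb: torch.Tensor, target_emb: torch.Tensor):
    """Hypothetical stand-in for src.utils.retrieval_eval (assumption).

    Assumes row i of target_emb is the correct match for row i of source_emb.
    """
    scores = source_emb @ target_emb.T                    # (n, n) similarity matrix
    ranking = scores.argsort(dim=1, descending=True)      # target indices sorted by score
    # 1-based rank of the true target for each source query.
    ranks = (ranking == torch.arange(len(scores))[:, None]).nonzero()[:, 1] + 1
    recall_at_1 = (ranks <= 1).float().mean()
    recall_at_5 = (ranks <= 5).float().mean()
    mrr = (1.0 / ranks.float()).mean()
    return recall_at_1, recall_at_5, mrr
```

Under this reading, the printed numbers indicate that StarEncoder retrieves the correct counterpart at rank 1 for roughly 72% of queries, while CodeBERT's pooler output scores far lower on the same data.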
