Commit a90f041

addressing PR comments
1 parent f7b2825 commit a90f041

File tree

2 files changed: +67 -82 lines changed


notebooks/influence_sentiment_analysis.ipynb

Lines changed: 25 additions & 82 deletions
@@ -83,16 +83,19 @@
 }
 ],
 "source": [
-"from datasets import load_dataset\n",
-"import torch\n",
-"from sklearn.metrics import f1_score\n",
+"from copy import deepcopy\n",
 "from typing import Sequence\n",
-"from pydvl.influence.torch import EkfacInfluence\n",
+"\n",
+"import matplotlib.pyplot as plt\n",
+"import torch\n",
 "import torch.nn.functional as F\n",
-"from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
-"from copy import deepcopy\n",
+"from datasets import load_dataset\n",
 "from IPython.display import HTML, display\n",
-"import matplotlib.pyplot as plt"
+"from sklearn.metrics import f1_score\n",
+"from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
+"\n",
+"from pydvl.influence.torch import EkfacInfluence\n",
+"from support.torch import ImdbDataset, ModelLogitsWrapper"
 ]
 },
 {
@@ -156,8 +159,9 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
+"Using the latest cached version of the module from /Users/fabio/.cache/huggingface/modules/datasets_modules/datasets/imdb/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0 (last modified on Thu Dec 14 21:47:25 2023) since it couldn't be found locally at imdb., or remotely on the Hugging Face Hub.\n",
 "Found cached dataset imdb (/Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)\n",
-"100%|██████████| 3/3 [00:00<00:00, 136.16it/s]\n"
+"100%|██████████| 3/3 [00:00<00:00, 111.43it/s]\n"
 ]
 }
 ],
@@ -265,18 +269,11 @@
 "tokenized_example = tokenizer(\n",
 "    [example_phrase],\n",
 "    return_tensors=\"pt\",\n",
-"    padding=True,\n",
 "    truncation=True,\n",
 ")\n",
 "\n",
-"tokenized_example_input_ids, tokenized_example_attention_mask = (\n",
-"    tokenized_example.input_ids,\n",
-"    tokenized_example.attention_mask,\n",
-")\n",
-"\n",
 "model_output = model(\n",
-"    input_ids=tokenized_example_input_ids,\n",
-"    attention_mask=tokenized_example_attention_mask,\n",
+"    input_ids=tokenized_example.input_ids,\n",
 ")"
 ]
 },
@@ -322,13 +319,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model_predictions = F.softmax(\n",
-"    model(\n",
-"        input_ids=tokenized_example_input_ids,\n",
-"        attention_mask=tokenized_example_attention_mask,\n",
-"    )[\"logits\"],\n",
-"    dim=1,\n",
-")"
+"model_predictions = F.softmax(model_output.logits, dim=1)"
 ]
 },
 {
@@ -386,7 +377,7 @@
 "output_type": "stream",
 "text": [
 "Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c1eaa46e94dfbfd3.arrow\n",
-"Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c5cc0d728c27151c.arrow\n"
+"Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-5dd4cdcbaa0bcc93.arrow\n"
 ]
 }
 ],
@@ -402,7 +393,7 @@
 "    logits = model(\n",
 "        input_ids=sample_test_set[\"input_ids\"],\n",
 "        attention_mask=sample_test_set[\"attention_mask\"],\n",
-"    )[0]\n",
+"    ).logits\n",
 "    predictions = torch.argmax(logits, dim=1)"
 ]
 },
@@ -435,7 +426,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"In this section we will define several helper function and classes that will be used in the rest of the notebook. "
+"In this section we will define two helper functions and classes that will be used in the rest of the notebook. "
 ]
 },
 {
@@ -444,47 +435,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"class ImdbDataset(torch.utils.data.Dataset):\n",
-"    \"\"\"\n",
-"    A PyTorch Dataset that takes in an HuggingFace Dataset object and tokenizes it.\n",
-"    The objects returned by __getitem__ are PyTorch tensors, with x being a tuple of\n",
-"    (input_ids, attention_mask), ready to be fed into a model, and y being the label.\n",
-"    It also returns the original text, for printing and debugging purposes.\n",
-"    \"\"\"\n",
-"\n",
-"    def __init__(self, dataset):\n",
-"        self.tokenized_ds = dataset.map(self.preprocess_function, batched=True)\n",
-"        self.encodings = self.tokenized_ds[\"input_ids\"]\n",
-"        self.attn_mask = self.tokenized_ds[\"attention_mask\"]\n",
-"        self.labels = self.tokenized_ds[\"label\"]\n",
-"\n",
-"    def preprocess_function(self, examples):\n",
-"        return tokenizer(examples[\"text\"], truncation=True, padding=True)\n",
-"\n",
-"    def __getitem__(self, idx):\n",
-"        x = torch.tensor([self.encodings[idx], self.attn_mask[idx]])\n",
-"        y = torch.tensor(self.labels[idx])\n",
-"        text = self.tokenized_ds[idx][\"text\"]\n",
-"        return x, y, text\n",
-"\n",
-"    def __len__(self):\n",
-"        return len(self.labels)\n",
-"\n",
-"\n",
-"class ModelLogitsWrapper(torch.nn.Module):\n",
-"    \"\"\"\n",
-"    A wrapper around a PyTorch model that returns only the logits and not the loss or\n",
-"    the attention mask.\n",
-"    \"\"\"\n",
-"\n",
-"    def __init__(self, model):\n",
-"        super().__init__()\n",
-"        self.model = model\n",
-"\n",
-"    def forward(self, x):\n",
-"        return self.model(x[:, 0], x[:, 1])[\"logits\"]\n",
-"\n",
-"\n",
 "def print_sentiment_preds(\n",
 "    model: ModelLogitsWrapper, model_input: torch.Tensor, true_label: int\n",
 "):\n",
@@ -620,8 +570,8 @@
 "text": [
 "Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9c48ce5d173413c7.arrow\n",
 "Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c1eaa46e94dfbfd3.arrow\n",
-"Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9aaaa3770ef3f9bf.arrow\n",
-"Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-7a8cbae367cafa72.arrow\n"
+" 0%| | 0/1 [00:00<?, ?ba/s]\n",
+" 0%| | 0/1 [00:00<?, ?ba/s]\n"
 ]
 }
 ],
@@ -638,8 +588,8 @@
 "    imdb[\"test\"].shuffle(seed=seed).select([i for i in list(range(NUM_TEST_EXAMPLES))])\n",
 ")\n",
 "\n",
-"train_dataset = ImdbDataset(small_train_dataset)\n",
-"test_dataset = ImdbDataset(small_test_dataset)\n",
+"train_dataset = ImdbDataset(small_train_dataset, tokenizer=tokenizer)\n",
+"test_dataset = ImdbDataset(small_test_dataset, tokenizer=tokenizer)\n",
 "\n",
 "train_dataloader = torch.utils.data.DataLoader(\n",
 "    train_dataset, batch_size=7, shuffle=True\n"
@@ -663,14 +613,7 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"K-FAC blocks - batch progress: 0%| | 0/15 [00:00<?, ?it/s]"
-]
-},
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"K-FAC blocks - batch progress: 100%|██████████| 15/15 [01:59<00:00, 7.98s/it]\n"
+"K-FAC blocks - batch progress: 100%|██████████| 15/15 [01:52<00:00, 7.53s/it]\n"
 ]
 }
 ],
@@ -707,7 +650,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"We calculate the influence of the first batch of training data over the first batch of test data. This because influence functions are very expensive to compute, and so to keep the runtime of this notebook within a few minutes we need to restrict ourselves a small number of examples."
+"We calculate the influence of the first batch of training data over the first batch of test data. This is because influence functions are very expensive to compute, and so to keep the runtime of this notebook within a few minutes we need to restrict ourselves to a small number of examples."
 ]
 },
 {
@@ -925,14 +868,14 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"This review is also quite hard to classify. This time it has a negative sentiment towards the movie, but it also contains several words with positive connotation. The parallel with the previous review is quite interesting, since both talk about an invasion. "
+"This review is also quite hard to classify. This time it has a negative sentiment towards the movie, but it also contains several words with positive connotation. The parallel with the previous review is quite interesting since both talk about an invasion. "
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"As it is often the case when analysing influence functions, it is hard to understand why these examples have such a large influence. We have seen some interesting patterns, mostly related to similarities in the language and words used, but it is hard to say with certainty if these are the reasons for the large influence.\n",
+"As is often the case when analysing influence functions, it is hard to understand why these examples have such a large influence. We have seen some interesting patterns, mostly related to similarities in the language and words used, but it is hard to say with certainty if these are the reasons for such a large influence.\n",
 "\n",
 "A [recent paper](https://arxiv.org/abs/2308.03296) has explored this topic in high detail, even for much larger language models than BERT (up to ~50 billion parameters!). Among the most interesting findings is that smaller models tend to rely a lot on word-to-word correspondencies, while larger models are more capable of extracting higher level concepts, drawing connections between words across multiple phrases.\n",
 "\n",

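For quick orientation, here is a minimal sketch of the inference pattern the simplified notebook cells follow after this change: tokenize the phrase, call the model with input_ids only, and read predictions from the named logits attribute instead of positional indexing. The checkpoint string below is an assumption for illustration only; the notebook loads its own fine-tuned model and tokenizer.

import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative checkpoint; the notebook uses its own fine-tuned sentiment model.
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

example_phrase = "This movie was a complete waste of time."

# A single example needs no padding; truncation guards against overlong input.
tokenized_example = tokenizer([example_phrase], return_tensors="pt", truncation=True)

# With one unpadded example the attention mask is all ones, so passing
# input_ids alone is sufficient (this is what the simplified cell does).
model_output = model(input_ids=tokenized_example.input_ids)

# Read logits via the named attribute rather than positional indexing.
model_predictions = F.softmax(model_output.logits, dim=1)
print(model_predictions)  # class probabilities, shape (1, num_labels)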
notebooks/support/torch.py

Lines changed: 42 additions & 0 deletions
@@ -255,6 +255,48 @@ def load(self) -> Losses:
             return pkl.load(file)
 
 
+class ImdbDataset(torch.utils.data.Dataset):
+    """
+    A PyTorch Dataset that takes in a HuggingFace Dataset object and tokenizes it.
+    The objects returned by __getitem__ are PyTorch tensors, with x being a tuple of
+    (input_ids, attention_mask), ready to be fed into a model, and y being the label.
+    It also returns the original text, for printing and debugging purposes.
+    """
+
+    def __init__(self, dataset, tokenizer):
+        self.tokenizer = tokenizer
+        self.tokenized_ds = dataset.map(self.preprocess_function, batched=True)
+        self.encodings = self.tokenized_ds["input_ids"]
+        self.attn_mask = self.tokenized_ds["attention_mask"]
+        self.labels = self.tokenized_ds["label"]
+
+    def preprocess_function(self, examples):
+        return self.tokenizer(examples["text"], truncation=True, padding=True)
+
+    def __getitem__(self, idx):
+        x = torch.tensor([self.encodings[idx], self.attn_mask[idx]])
+        y = torch.tensor(self.labels[idx])
+        text = self.tokenized_ds[idx]["text"]
+        return x, y, text
+
+    def __len__(self):
+        return len(self.labels)
+
+
+class ModelLogitsWrapper(torch.nn.Module):
+    """
+    A wrapper around a PyTorch model that returns only the logits and not the loss or
+    the attention mask.
+    """
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, x):
+        return self.model(x[:, 0], x[:, 1]).logits
+
+
 def process_imgnet_io(
     df: pd.DataFrame, labels: dict
 ) -> Tuple[torch.Tensor, torch.Tensor]:
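A short usage sketch of the two helpers added above, mirroring how the updated notebook constructs them (ImdbDataset with an explicit tokenizer feeding a DataLoader, and ModelLogitsWrapper around the classifier). The checkpoint name, subset size, and seed are placeholders for objects the notebook already has in scope.

import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from support.torch import ImdbDataset, ModelLogitsWrapper

# Placeholder checkpoint; in the notebook the fine-tuned model and tokenizer already exist.
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

imdb = load_dataset("imdb")
small_train_dataset = imdb["train"].shuffle(seed=42).select(range(100))

# ImdbDataset tokenizes the split up front via dataset.map and stacks
# (input_ids, attention_mask) into one (2, seq_len) tensor per example.
train_dataset = ImdbDataset(small_train_dataset, tokenizer=tokenizer)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=7, shuffle=True)

x_batch, y_batch, texts = next(iter(train_dataloader))

# ModelLogitsWrapper splits the stacked tensor back into (input_ids, attention_mask)
# and returns only the logits, the single-tensor interface expected downstream.
wrapped_model = ModelLogitsWrapper(model)
with torch.no_grad():
    logits = wrapped_model(x_batch)  # shape: (batch_size, num_labels)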
