|
83 | 83 | } |
84 | 84 | ], |
85 | 85 | "source": [ |
86 | | - "from datasets import load_dataset\n", |
87 | | - "import torch\n", |
88 | | - "from sklearn.metrics import f1_score\n", |
| 86 | + "from copy import deepcopy\n", |
89 | 87 | "from typing import Sequence\n", |
90 | | - "from pydvl.influence.torch import EkfacInfluence\n", |
| 88 | + "\n", |
| 89 | + "import matplotlib.pyplot as plt\n", |
| 90 | + "import torch\n", |
91 | 91 | "import torch.nn.functional as F\n", |
92 | | - "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", |
93 | | - "from copy import deepcopy\n", |
| 92 | + "from datasets import load_dataset\n", |
94 | 93 | "from IPython.display import HTML, display\n", |
95 | | - "import matplotlib.pyplot as plt" |
| 94 | + "from sklearn.metrics import f1_score\n", |
| 95 | + "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n", |
| 96 | + "\n", |
| 97 | + "from pydvl.influence.torch import EkfacInfluence\n", |
| 98 | + "from support.torch import ImdbDataset, ModelLogitsWrapper" |
96 | 99 | ] |
97 | 100 | }, |
98 | 101 | { |
|
156 | 159 | "name": "stderr", |
157 | 160 | "output_type": "stream", |
158 | 161 | "text": [ |
| 162 | + "Using the latest cached version of the module from /Users/fabio/.cache/huggingface/modules/datasets_modules/datasets/imdb/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0 (last modified on Thu Dec 14 21:47:25 2023) since it couldn't be found locally at imdb., or remotely on the Hugging Face Hub.\n", |
159 | 163 | "Found cached dataset imdb (/Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)\n", |
160 | | - "100%|██████████| 3/3 [00:00<00:00, 136.16it/s]\n" |
| 164 | + "100%|██████████| 3/3 [00:00<00:00, 111.43it/s]\n" |
161 | 165 | ] |
162 | 166 | } |
163 | 167 | ], |
|
265 | 269 | "tokenized_example = tokenizer(\n", |
266 | 270 | " [example_phrase],\n", |
267 | 271 | " return_tensors=\"pt\",\n", |
268 | | - " padding=True,\n", |
269 | 272 | " truncation=True,\n", |
270 | 273 | ")\n", |
271 | 274 | "\n", |
272 | | - "tokenized_example_input_ids, tokenized_example_attention_mask = (\n", |
273 | | - " tokenized_example.input_ids,\n", |
274 | | - " tokenized_example.attention_mask,\n", |
275 | | - ")\n", |
276 | | - "\n", |
277 | 275 | "model_output = model(\n", |
278 | | - " input_ids=tokenized_example_input_ids,\n", |
279 | | - " attention_mask=tokenized_example_attention_mask,\n", |
| 276 | + " input_ids=tokenized_example.input_ids,\n", |
280 | 277 | ")" |
281 | 278 | ] |
282 | 279 | }, |
|
322 | 319 | "metadata": {}, |
323 | 320 | "outputs": [], |
324 | 321 | "source": [ |
325 | | - "model_predictions = F.softmax(\n", |
326 | | - " model(\n", |
327 | | - " input_ids=tokenized_example_input_ids,\n", |
328 | | - " attention_mask=tokenized_example_attention_mask,\n", |
329 | | - " )[\"logits\"],\n", |
330 | | - " dim=1,\n", |
331 | | - ")" |
| 322 | + "model_predictions = F.softmax(model_output.logits, dim=1)" |
332 | 323 | ] |
333 | 324 | }, |
334 | 325 | { |
|
386 | 377 | "output_type": "stream", |
387 | 378 | "text": [ |
388 | 379 | "Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c1eaa46e94dfbfd3.arrow\n", |
389 | | - "Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c5cc0d728c27151c.arrow\n" |
| 380 | + "Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-5dd4cdcbaa0bcc93.arrow\n" |
390 | 381 | ] |
391 | 382 | } |
392 | 383 | ], |
|
402 | 393 | " logits = model(\n", |
403 | 394 | " input_ids=sample_test_set[\"input_ids\"],\n", |
404 | 395 | " attention_mask=sample_test_set[\"attention_mask\"],\n", |
405 | | - " )[0]\n", |
| 396 | + " ).logits\n", |
406 | 397 | " predictions = torch.argmax(logits, dim=1)" |
407 | 398 | ] |
408 | 399 | }, |
|
435 | 426 | "cell_type": "markdown", |
436 | 427 | "metadata": {}, |
437 | 428 | "source": [ |
438 | | - "In this section we will define several helper function and classes that will be used in the rest of the notebook. " |
| 429 | + "In this section we will define two helper function and classes that will be used in the rest of the notebook. " |
439 | 430 | ] |
440 | 431 | }, |
441 | 432 | { |
|
444 | 435 | "metadata": {}, |
445 | 436 | "outputs": [], |
446 | 437 | "source": [ |
447 | | - "class ImdbDataset(torch.utils.data.Dataset):\n", |
448 | | - " \"\"\"\n", |
449 | | - " A PyTorch Dataset that takes in an HuggingFace Dataset object and tokenizes it.\n", |
450 | | - " The objects returned by __getitem__ are PyTorch tensors, with x being a tuple of\n", |
451 | | - " (input_ids, attention_mask), ready to be fed into a model, and y being the label.\n", |
452 | | - " It also returns the original text, for printing and debugging purposes.\n", |
453 | | - " \"\"\"\n", |
454 | | - "\n", |
455 | | - " def __init__(self, dataset):\n", |
456 | | - " self.tokenized_ds = dataset.map(self.preprocess_function, batched=True)\n", |
457 | | - " self.encodings = self.tokenized_ds[\"input_ids\"]\n", |
458 | | - " self.attn_mask = self.tokenized_ds[\"attention_mask\"]\n", |
459 | | - " self.labels = self.tokenized_ds[\"label\"]\n", |
460 | | - "\n", |
461 | | - " def preprocess_function(self, examples):\n", |
462 | | - " return tokenizer(examples[\"text\"], truncation=True, padding=True)\n", |
463 | | - "\n", |
464 | | - " def __getitem__(self, idx):\n", |
465 | | - " x = torch.tensor([self.encodings[idx], self.attn_mask[idx]])\n", |
466 | | - " y = torch.tensor(self.labels[idx])\n", |
467 | | - " text = self.tokenized_ds[idx][\"text\"]\n", |
468 | | - " return x, y, text\n", |
469 | | - "\n", |
470 | | - " def __len__(self):\n", |
471 | | - " return len(self.labels)\n", |
472 | | - "\n", |
473 | | - "\n", |
474 | | - "class ModelLogitsWrapper(torch.nn.Module):\n", |
475 | | - " \"\"\"\n", |
476 | | - " A wrapper around a PyTorch model that returns only the logits and not the loss or\n", |
477 | | - " the attention mask.\n", |
478 | | - " \"\"\"\n", |
479 | | - "\n", |
480 | | - " def __init__(self, model):\n", |
481 | | - " super().__init__()\n", |
482 | | - " self.model = model\n", |
483 | | - "\n", |
484 | | - " def forward(self, x):\n", |
485 | | - " return self.model(x[:, 0], x[:, 1])[\"logits\"]\n", |
486 | | - "\n", |
487 | | - "\n", |
488 | 438 | "def print_sentiment_preds(\n", |
489 | 439 | " model: ModelLogitsWrapper, model_input: torch.Tensor, true_label: int\n", |
490 | 440 | "):\n", |
|
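The two classes removed above now live in the notebook's `support.torch` module (see the new import in the first hunk). Based on their previous inline definitions, and assuming the imported versions keep the same interface, a minimal usage sketch looks like the following; the variable names are illustrative, not taken from the notebook:

```python
# Sketch based on the removed inline definitions above; `model` and
# `train_dataset` are assumed to exist as defined elsewhere in the notebook.
wrapped_model = ModelLogitsWrapper(model)

x, y, text = train_dataset[0]           # x stacks input_ids and attention_mask -> (2, seq_len)
logits = wrapped_model(x.unsqueeze(0))  # add a batch dimension -> (1, num_labels)
predicted_label = torch.argmax(logits, dim=1)
```

This is also why later cells can feed a single tensor to the wrapped model: the wrapper splits it back into `input_ids` and `attention_mask` before calling the underlying HuggingFace model.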
620 | 570 | "text": [ |
621 | 571 | "Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9c48ce5d173413c7.arrow\n", |
622 | 572 | "Loading cached shuffled indices for dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-c1eaa46e94dfbfd3.arrow\n", |
623 | | - "Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9aaaa3770ef3f9bf.arrow\n", |
624 | | - "Loading cached processed dataset at /Users/fabio/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-7a8cbae367cafa72.arrow\n" |
| 573 | + " 0%| | 0/1 [00:00<?, ?ba/s]\n", |
| 574 | + " 0%| | 0/1 [00:00<?, ?ba/s]\n" |
625 | 575 | ] |
626 | 576 | } |
627 | 577 | ], |
|
638 | 588 | " imdb[\"test\"].shuffle(seed=seed).select([i for i in list(range(NUM_TEST_EXAMPLES))])\n", |
639 | 589 | ")\n", |
640 | 590 | "\n", |
641 | | - "train_dataset = ImdbDataset(small_train_dataset)\n", |
642 | | - "test_dataset = ImdbDataset(small_test_dataset)\n", |
| 591 | + "train_dataset = ImdbDataset(small_train_dataset, tokenizer=tokenizer)\n", |
| 592 | + "test_dataset = ImdbDataset(small_test_dataset, tokenizer=tokenizer)\n", |
643 | 593 | "\n", |
644 | 594 | "train_dataloader = torch.utils.data.DataLoader(\n", |
645 | 595 | " train_dataset, batch_size=7, shuffle=True\n", |
|
663 | 613 | "name": "stderr", |
664 | 614 | "output_type": "stream", |
665 | 615 | "text": [ |
666 | | - "K-FAC blocks - batch progress: 0%| | 0/15 [00:00<?, ?it/s]" |
667 | | - ] |
668 | | - }, |
669 | | - { |
670 | | - "name": "stderr", |
671 | | - "output_type": "stream", |
672 | | - "text": [ |
673 | | - "K-FAC blocks - batch progress: 100%|██████████| 15/15 [01:59<00:00, 7.98s/it]\n" |
| 616 | + "K-FAC blocks - batch progress: 100%|██████████| 15/15 [01:52<00:00, 7.53s/it]\n" |
674 | 617 | ] |
675 | 618 | } |
676 | 619 | ], |
|
707 | 650 | "cell_type": "markdown", |
708 | 651 | "metadata": {}, |
709 | 652 | "source": [ |
710 | | - "We calculate the influence of the first batch of training data over the first batch of test data. This because influence functions are very expensive to compute, and so to keep the runtime of this notebook within a few minutes we need to restrict ourselves a small number of examples." |
| 653 | + "We calculate the influence of the first batch of training data over the first batch of test data. This is because influence functions are very expensive to compute, and so to keep the runtime of this notebook within a few minutes we need to restrict ourselves to a small number of examples." |
711 | 654 | ] |
712 | 655 | }, |
713 | 656 | { |
|
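For orientation only, the computation this cell describes might look roughly like the sketch below. It assumes the EK-FAC influence model fitted in the previous cell is named `ekfac_influence_model`, that the dataloaders yield `(x, y, text)` tuples as returned by `ImdbDataset`, and that pyDVL's `influences(x_test, y_test, x, y)` signature applies; the notebook's own cells are authoritative.

```python
# Rough sketch, not the notebook's exact code: influence of one training batch
# on one test batch, using the already-fitted EK-FAC influence model.
x_test, y_test, _ = next(iter(test_dataloader))
x_train, y_train, _ = next(iter(train_dataloader))

ekfac_train_influences = ekfac_influence_model.influences(x_test, y_test, x_train, y_train)
# ekfac_train_influences[i, j] estimates the effect of training example j
# on the loss of test example i.
```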
925 | 868 | "cell_type": "markdown", |
926 | 869 | "metadata": {}, |
927 | 870 | "source": [ |
928 | | - "This review is also quite hard to classify. This time it has a negative sentiment towards the movie, but it also contains several words with positive connotation. The parallel with the previous review is quite interesting, since both talk about an invasion. " |
| 871 | + "This review is also quite hard to classify. This time it has a negative sentiment towards the movie, but it also contains several words with positive connotation. The parallel with the previous review is quite interesting since both talk about an invasion. " |
929 | 872 | ] |
930 | 873 | }, |
931 | 874 | { |
932 | 875 | "cell_type": "markdown", |
933 | 876 | "metadata": {}, |
934 | 877 | "source": [ |
935 | | - "As it is often the case when analysing influence functions, it is hard to understand why these examples have such a large influence. We have seen some interesting patterns, mostly related to similarities in the language and words used, but it is hard to say with certainty if these are the reasons for the large influence.\n", |
| 878 | + "As it is often the case when analysing influence functions, it is hard to understand why these examples have such a large influence. We have seen some interesting patterns, mostly related to similarities in the language and words used, but it is hard to say with certainty if these are the reasons for such a large influence.\n", |
936 | 879 | "\n", |
937 | 880 | "A [recent paper](https://arxiv.org/abs/2308.03296) has explored this topic in high detail, even for much larger language models than BERT (up to ~50 billion parameters!). Among the most interesting findings is that smaller models tend to rely a lot on word-to-word correspondencies, while larger models are more capable of extracting higher level concepts, drawing connections between words across multiple phrases.\n", |
938 | 881 | "\n", |
|