|
3 | 3 | { |
4 | 4 | "cell_type": "markdown", |
5 | 5 | "metadata": { |
6 | | - "id": "view-in-github", |
7 | | - "colab_type": "text" |
| 6 | + "colab_type": "text", |
| 7 | + "id": "view-in-github" |
8 | 8 | }, |
9 | 9 | "source": [ |
10 | 10 | "<a href=\"https://colab.research.google.com/github/AdaptiveMotorControlLab/AROS/blob/main/Notebooks/AROS.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" |
11 | 11 | ] |
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "markdown", |
15 | | - "source": [ |
16 | | - "## Adversarially Robust Out-of-Distribution Detection Using Lyapunov-Stabilized Embeddings" |
17 | | - ], |
18 | 15 | "metadata": { |
19 | 16 | "id": "RFxEz28oe7dE" |
20 | | - } |
| 17 | + }, |
| 18 | + "source": [ |
| 19 | + "## Adversarially Robust Out-of-Distribution Detection Using Lyapunov-Stabilized Embeddings" |
| 20 | + ] |
21 | 21 | }, |
22 | 22 | { |
23 | 23 | "cell_type": "markdown", |
| 24 | + "metadata": { |
| 25 | + "id": "ZL5Va1N940xJ" |
| 26 | + }, |
24 | 27 | "source": [ |
25 | 28 | "This notebook is designed to replicate and analyze the results presented in Table 1 of the AROS paper, focusing on out-of-distribution detection performance under both attack scenarios and clean evaluation. The dataset configurations involve using CIFAR-10 and CIFAR-100 as in-distribution and out-of-distribution datasets. The notebook is structured to load a pre-trained model as the encoder, followed by generating fake OOD embeddings through sampling. The model is then trained using the designed loss function and evaluated across various OOD detection benchmarks to assess its performance under different conditions.\n", |
26 | 29 | "\n" |
27 | | - ], |
28 | | - "metadata": { |
29 | | - "id": "ZL5Va1N940xJ" |
30 | | - } |
| 30 | + ] |
31 | 31 | }, |
32 | 32 | { |
33 | 33 | "cell_type": "markdown", |
|
40 | 40 | }, |
41 | 41 | { |
42 | 42 | "cell_type": "code", |
43 | | - "source": [ |
44 | | - "!git clone https://github.com/AdaptiveMotorControlLab/AROS.git" |
45 | | - ], |
| 43 | + "execution_count": null, |
46 | 44 | "metadata": { |
47 | 45 | "id": "TdY-7pyGq4oN" |
48 | 46 | }, |
49 | | - "execution_count": null, |
50 | | - "outputs": [] |
| 47 | + "outputs": [], |
| 48 | + "source": [ |
| 49 | + "!git clone https://github.com/MMathisLab/AROS.git" |
| 50 | + ] |
51 | 51 | }, |
52 | 52 | { |
53 | 53 | "cell_type": "code", |
| 54 | + "execution_count": null, |
| 55 | + "metadata": { |
| 56 | + "id": "owrQtpTxrbth" |
| 57 | + }, |
| 58 | + "outputs": [], |
54 | 59 | "source": [ |
55 | 60 | "%cd /content/AROS\n", |
56 | 61 | "%ls\n", |
57 | 62 | "!pip install -r requirements.txt" |
58 | | - ], |
59 | | - "metadata": { |
60 | | - "id": "owrQtpTxrbth" |
61 | | - }, |
62 | | - "execution_count": null, |
63 | | - "outputs": [] |
| 63 | + ] |
64 | 64 | }, |
65 | 65 | { |
66 | 66 | "cell_type": "code", |
| 67 | + "execution_count": null, |
| 68 | + "metadata": { |
| 69 | + "id": "WgsBOHhNrtYD" |
| 70 | + }, |
| 71 | + "outputs": [], |
67 | 72 | "source": [ |
68 | 73 | "import argparse\n", |
69 | 74 | "import torch\n", |
70 | 75 | "import torch.nn as nn\n", |
71 | 76 | "from tqdm.notebook import tqdm" |
72 | | - ], |
73 | | - "metadata": { |
74 | | - "id": "WgsBOHhNrtYD" |
75 | | - }, |
76 | | - "execution_count": null, |
77 | | - "outputs": [] |
| 77 | + ] |
78 | 78 | }, |
79 | 79 | { |
80 | 80 | "cell_type": "code", |
|
112 | 112 | "\n", |
113 | 113 | "# Define the hyperparameters controlled via CLI 'Ding2020MMA'\n", |
114 | 114 | "\n", |
| 115 | + "\n", |
| 116 | + "parser.add_argument('--fast', type=bool, default=True, help='Toggle between fast and full fake data generation modes')\n", |
| 117 | + "parser.add_argument('--epoch1', type=int, default=2, help='Number of epochs for stage 1')\n", |
| 118 | + "parser.add_argument('--epoch2', type=int, default=1, help='Number of epochs for stage 2')\n", |
| 119 | + "parser.add_argument('--epoch3', type=int, default=2, help='Number of epochs for stage 3')\n", |
115 | 120 | "parser.add_argument('--in_dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')\n", |
116 | 121 | "parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')\n", |
117 | 122 | "parser.add_argument('--noise_std', type=float, default=1, help='Standard deviation of noise for generating noisy fake embeddings')\n", |
|
474 | 479 | "accelerator": "GPU", |
475 | 480 | "colab": { |
476 | 481 | "gpuType": "A100", |
477 | | - "provenance": [], |
| 482 | + "include_colab_link": true, |
478 | 483 | "machine_shape": "hm", |
479 | | - "include_colab_link": true |
| 484 | + "provenance": [] |
480 | 485 | }, |
481 | 486 | "kernelspec": { |
482 | 487 | "display_name": "Python 3", |
|
0 commit comments