|
3 | 3 | { |
4 | 4 | "cell_type": "markdown", |
5 | 5 | "metadata": { |
6 | | - "id": "view-in-github", |
7 | | - "colab_type": "text" |
| 6 | + "colab_type": "text", |
| 7 | + "id": "view-in-github" |
8 | 8 | }, |
9 | 9 | "source": [ |
10 | 10 | "<a href=\"https://colab.research.google.com/github/AdaptiveMotorControlLab/AROS/blob/main/Notebooks/Ablation_Study.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" |
11 | 11 | ] |
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "markdown", |
15 | | - "source": [ |
16 | | - "## AROS, Ablation Study" |
17 | | - ], |
18 | 15 | "metadata": { |
19 | 16 | "id": "G1Ues10_fww5" |
20 | | - } |
| 17 | + }, |
| 18 | + "source": [ |
| 19 | + "## AROS, Ablation Study" |
| 20 | + ] |
| 21 | + }, |
| 22 | + { |
| 23 | + "cell_type": "code", |
| 24 | + "execution_count": null, |
| 25 | + "metadata": {}, |
| 26 | + "outputs": [], |
| 27 | + "source": [ |
| 28 | + "!git clone https://github.com/AdaptiveMotorControlLab/AROS" |
| 29 | + ] |
| 30 | + }, |
| 31 | + { |
| 32 | + "cell_type": "code", |
| 33 | + "execution_count": null, |
| 34 | + "metadata": {}, |
| 35 | + "outputs": [], |
| 36 | + "source": [ |
| 37 | + "!pip install -r ./AROS/requirements.txt\n", |
| 38 | + "cd ./AROS/AROS" |
| 39 | + ] |
| 40 | + }, |
| 41 | + { |
| 42 | + "cell_type": "code", |
| 43 | + "execution_count": null, |
| 44 | + "metadata": {}, |
| 45 | + "outputs": [], |
| 46 | + "source": [ |
| 47 | + "import argparse\n", |
| 48 | + "import torch\n", |
| 49 | + "import torch.nn as nn\n", |
| 50 | + "from tqdm.notebook import tqdm" |
| 51 | + ] |
21 | 52 | }, |
22 | 53 | { |
23 | 54 | "cell_type": "code", |
|
91 | 122 | } |
92 | 123 | ], |
93 | 124 | "source": [ |
94 | | - "!pip install -r requirements.txt\n", |
95 | | - "import argparse\n", |
96 | | - "import torch\n", |
97 | | - "import torch.nn as nn\n", |
98 | 125 | "from evaluate import *\n", |
99 | 126 | "from utils import *\n", |
100 | 127 | "from tqdm.notebook import tqdm\n", |
|
112 | 139 | "source": [ |
113 | 140 | "parser = argparse.ArgumentParser(description=\"Hyperparameters for the script\")\n", |
114 | 141 | "\n", |
115 | | - "# Define the hyperparameters controlled via CLI 'Ding2020MMA'\n", |
116 | | - "\n", |
| 142 | + " \n", |
117 | 143 | "parser.add_argument('--in_dataset', type=str, default='cifar100', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')\n", |
118 | 144 | "parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')\n", |
119 | 145 | "parser.add_argument('--noise_std', type=float, default=1, help='Standard deviation of noise for generating noisy fake embeddings')\n", |
|
144 | 170 | "cell_type": "code", |
145 | 171 | "execution_count": null, |
146 | 172 | "metadata": { |
147 | | - "id": "g2TltXvg7MfF", |
148 | | - "outputId": "4df864e7-e14b-4db4-e1ae-06e33c9b11be", |
149 | 173 | "colab": { |
150 | 174 | "referenced_widgets": [ |
151 | 175 | "59296a90b8c84b1c94648a4c5d68a43b", |
152 | 176 | "ad54c341af6e400280d000b3725f08ee" |
153 | 177 | ] |
154 | | - } |
| 178 | + }, |
| 179 | + "id": "g2TltXvg7MfF", |
| 180 | + "outputId": "4df864e7-e14b-4db4-e1ae-06e33c9b11be" |
155 | 181 | }, |
156 | 182 | "outputs": [ |
157 | 183 | { |
|
217 | 243 | "cell_type": "code", |
218 | 244 | "execution_count": null, |
219 | 245 | "metadata": { |
220 | | - "id": "QeC-30C5ImKg", |
221 | | - "outputId": "793be5f3-3307-4a3d-8e5f-177ac212d30a", |
222 | 246 | "colab": { |
223 | 247 | "referenced_widgets": [ |
224 | 248 | "c9c97585bef049ca9974797d1d5964ab", |
225 | 249 | "97f0832ab970458f947318195735214b" |
226 | 250 | ] |
227 | | - } |
| 251 | + }, |
| 252 | + "id": "QeC-30C5ImKg", |
| 253 | + "outputId": "793be5f3-3307-4a3d-8e5f-177ac212d30a" |
228 | 254 | }, |
229 | 255 | "outputs": [ |
230 | 256 | { |
|
0 commit comments