Skip to content

Commit 7a808ae

Browse files
hsirmMMathisLab
andauthored
Notebook Bug Resolved (#8)
* Notebook Bug Resolved --------- Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent 04a7c75 commit 7a808ae

File tree

2 files changed

+212
-77
lines changed

2 files changed

+212
-77
lines changed

Notebooks/AROS.ipynb

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,31 @@
33
{
44
"cell_type": "markdown",
55
"metadata": {
6-
"id": "view-in-github",
7-
"colab_type": "text"
6+
"colab_type": "text",
7+
"id": "view-in-github"
88
},
99
"source": [
1010
"<a href=\"https://colab.research.google.com/github/AdaptiveMotorControlLab/AROS/blob/main/Notebooks/AROS.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
1111
]
1212
},
1313
{
1414
"cell_type": "markdown",
15-
"source": [
16-
"## Adversarially Robust Out-of-Distribution Detection Using Lyapunov-Stabilized Embeddings"
17-
],
1815
"metadata": {
1916
"id": "RFxEz28oe7dE"
20-
}
17+
},
18+
"source": [
19+
"## Adversarially Robust Out-of-Distribution Detection Using Lyapunov-Stabilized Embeddings"
20+
]
2121
},
2222
{
2323
"cell_type": "markdown",
24+
"metadata": {
25+
"id": "ZL5Va1N940xJ"
26+
},
2427
"source": [
2528
"This notebook is designed to replicate and analyze the results presented in Table 1 of the AROS paper, focusing on out-of-distribution detection performance under both attack scenarios and clean evaluation. The dataset configurations involve using CIFAR-10 and CIFAR-100 as in-distribution and out-of-distribution datasets. The notebook is structured to load a pre-trained model as the encoder, followed by generating fake OOD embeddings through sampling. The model is then trained using the designed loss function and evaluated across various OOD detection benchmarks to assess its performance under different conditions.\n",
2629
"\n"
27-
],
28-
"metadata": {
29-
"id": "ZL5Va1N940xJ"
30-
}
30+
]
3131
},
3232
{
3333
"cell_type": "markdown",
@@ -40,41 +40,41 @@
4040
},
4141
{
4242
"cell_type": "code",
43-
"source": [
44-
"!git clone https://github.com/AdaptiveMotorControlLab/AROS.git"
45-
],
43+
"execution_count": null,
4644
"metadata": {
4745
"id": "TdY-7pyGq4oN"
4846
},
49-
"execution_count": null,
50-
"outputs": []
47+
"outputs": [],
48+
"source": [
49+
"!git clone https://github.com/MMathisLab/AROS.git"
50+
]
5151
},
5252
{
5353
"cell_type": "code",
54+
"execution_count": null,
55+
"metadata": {
56+
"id": "owrQtpTxrbth"
57+
},
58+
"outputs": [],
5459
"source": [
5560
"%cd /content/AROS\n",
5661
"%ls\n",
5762
"!pip install -r requirements.txt"
58-
],
59-
"metadata": {
60-
"id": "owrQtpTxrbth"
61-
},
62-
"execution_count": null,
63-
"outputs": []
63+
]
6464
},
6565
{
6666
"cell_type": "code",
67+
"execution_count": null,
68+
"metadata": {
69+
"id": "WgsBOHhNrtYD"
70+
},
71+
"outputs": [],
6772
"source": [
6873
"import argparse\n",
6974
"import torch\n",
7075
"import torch.nn as nn\n",
7176
"from tqdm.notebook import tqdm"
72-
],
73-
"metadata": {
74-
"id": "WgsBOHhNrtYD"
75-
},
76-
"execution_count": null,
77-
"outputs": []
77+
]
7878
},
7979
{
8080
"cell_type": "code",
@@ -112,6 +112,11 @@
112112
"\n",
113113
"# Define the hyperparameters controlled via CLI 'Ding2020MMA'\n",
114114
"\n",
115+
"\n",
116+
"parser.add_argument('--fast', type=bool, default=True, help='Toggle between fast and full fake data generation modes')\n",
117+
"parser.add_argument('--epoch1', type=int, default=2, help='Number of epochs for stage 1')\n",
118+
"parser.add_argument('--epoch2', type=int, default=1, help='Number of epochs for stage 2')\n",
119+
"parser.add_argument('--epoch3', type=int, default=2, help='Number of epochs for stage 3')\n",
115120
"parser.add_argument('--in_dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')\n",
116121
"parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')\n",
117122
"parser.add_argument('--noise_std', type=float, default=1, help='Standard deviation of noise for generating noisy fake embeddings')\n",
@@ -474,9 +479,9 @@
474479
"accelerator": "GPU",
475480
"colab": {
476481
"gpuType": "A100",
477-
"provenance": [],
482+
"include_colab_link": true,
478483
"machine_shape": "hm",
479-
"include_colab_link": true
484+
"provenance": []
480485
},
481486
"kernelspec": {
482487
"display_name": "Python 3",

0 commit comments

Comments
 (0)