Commit a49334d

Adding initial anomaly detection nb
1 parent e8434ff commit a49334d

1 file changed

+303
-0
lines changed

@@ -0,0 +1,303 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "d313573c-0644-44ca-96ef-571ebbb3c250",
6+
"metadata": {},
7+
"source": [
8+
"# Anomaly Detection: MNIST vs. TF Flowers\n",
9+
"This Jupyter notebook explores *anomaly detection* with an *autoencoder*: it first trains a simple *autoencoder* (the fully connected `MinNDAE` model) on MNIST, and then compares its *reconstruction error* on MNIST against its error on the TF Flowers images."
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"id": "799a59d8-7e7f-450c-9fd0-e21749b7cd75",
15+
"metadata": {},
16+
"source": [
17+
"## Setup\n",
18+
"Install the necessary packages when running on Google Colab ..."
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"id": "9f3b8122-297d-4662-acc7-d7a7b930243d",
25+
"metadata": {},
26+
"outputs": [],
27+
"source": [
28+
"# check for colab\n",
29+
"if \"google.colab\" in str(get_ipython()):\n",
30+
" # install colab dependencies\n",
31+
" !pip install git+https://github.com/DiogenesAnalytics/autoencoder"
32+
]
33+
},
34+
{
35+
"cell_type": "markdown",
36+
"id": "dde3750d-4828-430b-9cc8-231066c37d35",
37+
"metadata": {},
38+
"source": [
39+
"## Get MNIST Data\n",
40+
"We will use `keras.datasets` to get the `MNIST` dataset, and then do some *normalizing* and *reshaping* to prepare it for the *autoencoder*."
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": null,
46+
"id": "bc7c0488-abe6-453a-aae4-e3f9a392736a",
47+
"metadata": {},
48+
"outputs": [],
49+
"source": [
50+
"# get necessary libs for data/preprocessing\n",
51+
"import tensorflow as tf\n",
52+
"from keras.datasets import mnist\n",
53+
"\n",
54+
"# load the data\n",
55+
"(x_train, _), (x_test, _) = mnist.load_data()\n",
56+
"\n",
57+
"# preprocess the data (normalize)\n",
58+
"x_train = x_train.astype(\"float32\") / 255.\n",
59+
"x_test = x_test.astype(\"float32\") / 255.\n",
60+
"\n",
61+
"# add grayscale dimension\n",
62+
"x_train = tf.expand_dims(x_train, axis=-1)\n",
63+
"x_test = tf.expand_dims(x_test, axis=-1)\n",
64+
"\n",
65+
"# convert to tf datasets\n",
66+
"train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train))\n",
67+
"test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test))\n",
68+
"\n",
69+
"# set a few params\n",
70+
"BATCH_SIZE = 64\n",
71+
"SHUFFLE_BUFFER_SIZE = 100\n",
72+
"\n",
73+
"# update with batch/buffer size\n",
74+
"train_ds = train_ds.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)\n",
75+
"test_ds = test_ds.batch(BATCH_SIZE)"
76+
]
77+
},
78+
{
79+
"cell_type": "markdown",
80+
"id": "0d588a1c-a082-405a-9b55-4399ff580879",
81+
"metadata": {},
82+
"source": [
83+
"## Get tf_flowers Data\n",
84+
"The [TensorFlow Flowers](https://www.tensorflow.org/datasets/catalog/tf_flowers) dataset first needs to be downloaded, and then preprocessed."
85+
]
86+
},
87+
{
88+
"cell_type": "code",
89+
"execution_count": null,
90+
"id": "03bec295-bd65-4f3b-9149-e82b311246b8",
91+
"metadata": {},
92+
"outputs": [],
93+
"source": [
94+
"# libs for tf flowers data\n",
95+
"import keras\n",
96+
"import pathlib\n",
97+
"\n",
98+
"# data location\n",
99+
"DATASET_URL = \"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\"\n",
100+
"\n",
101+
"# download, get path, and convert to pathlib obj\n",
102+
"TF_FLOWERS_DATA_DIR = pathlib.Path(\n",
103+
" keras.utils.get_file(\"flower_photos\", origin=DATASET_URL, untar=True, cache_dir=\"./data/keras\")\n",
104+
")"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": null,
110+
"id": "20ae5d2e-0209-4edb-ae66-ab5746ad278a",
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
"# get keras image dataset util func\n",
115+
"from keras.utils import image_dataset_from_directory\n",
116+
"\n",
117+
"# create normalization func\n",
118+
"def normalize(x):\n",
119+
" return x / 255.\n",
120+
"\n",
121+
"# use keras util to load raw images into tensorflow.data.Dataset\n",
122+
"anomalous_data = image_dataset_from_directory(\n",
123+
" TF_FLOWERS_DATA_DIR,\n",
124+
" labels=None,\n",
125+
" color_mode=\"grayscale\",\n",
126+
" validation_split=None,\n",
127+
" shuffle=True,\n",
128+
" subset=None,\n",
129+
" seed=42,\n",
130+
" image_size=(28, 28),\n",
131+
" batch_size=3670,\n",
132+
").map(normalize)"
133+
]
134+
},
135+
{
136+
"cell_type": "markdown",
137+
"id": "ee2e329b-1c9d-4e09-af91-b38e93e6613f",
138+
"metadata": {},
139+
"source": [
140+
"## Autoencoder Training\n",
141+
"Finally the *autoencoder* can be trained ..."
142+
]
143+
},
144+
{
145+
"cell_type": "code",
146+
"execution_count": null,
147+
"id": "2d440a72-b357-4794-a5b9-f4b4ee790524",
148+
"metadata": {},
149+
"outputs": [],
150+
"source": [
151+
"# get libs for training ae\n",
152+
"from autoencoder.model.minimal import MinNDAE, MinNDParams\n",
153+
"\n",
154+
"# set up config\n",
155+
"config = MinNDParams(\n",
156+
" l0={\"input_shape\": (28, 28, 1)},\n",
157+
" l2={\"units\": 32 * 1},\n",
158+
" l3={\"units\": 28 * 28 * 1},\n",
159+
" l4={\"target_shape\": (28, 28, 1)},\n",
160+
")\n",
161+
"\n",
162+
"# get ae instance\n",
163+
"autoencoder = MinNDAE(config)\n",
164+
"\n",
165+
"# check network topology\n",
166+
"autoencoder.summary()"
167+
]
168+
},
169+
{
170+
"cell_type": "code",
171+
"execution_count": null,
172+
"id": "da883534-e913-494b-bafe-564d8b25f76c",
173+
"metadata": {},
174+
"outputs": [],
175+
"source": [
176+
"# get code for callbacks and custom loss function\n",
177+
"from autoencoder.training import build_anomaly_loss_function\n",
178+
"from keras.callbacks import EarlyStopping\n",
179+
"\n",
180+
"# create callback\n",
181+
"early_stop_callback = EarlyStopping(monitor=\"val_anomaly_diff\", patience=2)\n",
182+
"\n",
183+
"# get custom loss func\n",
184+
"custom_loss = build_anomaly_loss_function(next(iter(anomalous_data)), autoencoder)\n",
185+
"\n",
186+
"# compile ae\n",
187+
"autoencoder.compile(\n",
188+
" optimizer=\"adam\",\n",
189+
" loss=custom_loss,\n",
190+
" metrics=[custom_loss],\n",
191+
")\n",
192+
"\n",
193+
"# begin model fit\n",
194+
"autoencoder.fit(\n",
195+
" x=train_ds,\n",
196+
" epochs=10**2,\n",
197+
" validation_data=test_ds,\n",
198+
" callbacks=[early_stop_callback],\n",
199+
")"
200+
]
201+
},
202+
{
203+
"cell_type": "code",
204+
"execution_count": null,
205+
"id": "2ece10e9-bcd3-44f9-a15d-bf2ed9b06e4f",
206+
"metadata": {},
207+
"outputs": [],
208+
"source": [
209+
"# view training loss\n",
210+
"autoencoder.training_history()"
211+
]
212+
},
213+
{
214+
"cell_type": "markdown",
215+
"id": "a4629952-4d4e-411e-9841-a38e9787ca43",
216+
"metadata": {},
217+
"source": [
218+
"## Reconstruction Error Distribution\n",
219+
"Now let us take a peek at the results and see how well the *autoencoder* works as an *anomaly detector* (i.e. how **low** vs. how **high** the *reconstruction error* is for the MNIST test set and the anomalous dataset, respectively)."
220+
]
221+
},
222+
{
223+
"cell_type": "code",
224+
"execution_count": null,
225+
"id": "a782719c-b51b-414d-a8e5-483a9efde43b",
226+
"metadata": {},
227+
"outputs": [],
228+
"source": [
229+
"# get custom anomaly detection class\n",
230+
"from autoencoder.data.anomaly import AnomalyDetector\n",
231+
"\n",
232+
"# get mnist instance\n",
233+
"mnist_recon_error = AnomalyDetector(autoencoder, test_ds, axis=(1, 2, 3))\n",
234+
"\n",
235+
"# calculate recon error\n",
236+
"mnist_recon_error.calculate_error()"
237+
]
238+
},
239+
{
240+
"cell_type": "code",
241+
"execution_count": null,
242+
"id": "aa9a70d6-9e58-42ea-9602-27660316292f",
243+
"metadata": {},
244+
"outputs": [],
245+
"source": [
246+
"# get tf flowers instance\n",
247+
"tfflower_recon_error = AnomalyDetector(autoencoder, anomalous_data)\n",
248+
"\n",
249+
"# calculate recon error\n",
250+
"tfflower_recon_error.calculate_error()"
251+
]
252+
},
253+
{
254+
"cell_type": "code",
255+
"execution_count": null,
256+
"id": "d7d09932-87c7-4643-90ab-d20fb1174ff8",
257+
"metadata": {},
258+
"outputs": [],
259+
"source": [
260+
"# turn on interactive plot\n",
261+
"%matplotlib widget"
262+
]
263+
},
264+
{
265+
"cell_type": "code",
266+
"execution_count": null,
267+
"id": "3a6c9f15-2ca1-405a-930a-a944ccd21e13",
268+
"metadata": {},
269+
"outputs": [],
270+
"source": [
271+
"# now compare recon error distributions\n",
272+
"mnist_recon_error.histogram(\n",
273+
" \"MNIST Anomaly Detection Using TF Flowers: MinNDAE\",\n",
274+
" label=\"mnist\",\n",
275+
" bins=[100, 100],\n",
276+
" additional_data=[tfflower_recon_error],\n",
277+
" additional_labels=[\"tf_flowers\"],\n",
278+
")"
279+
]
280+
}
281+
],
282+
"metadata": {
283+
"kernelspec": {
284+
"display_name": "Python 3 (ipykernel)",
285+
"language": "python",
286+
"name": "python3"
287+
},
288+
"language_info": {
289+
"codemirror_mode": {
290+
"name": "ipython",
291+
"version": 3
292+
},
293+
"file_extension": ".py",
294+
"mimetype": "text/x-python",
295+
"name": "python",
296+
"nbconvert_exporter": "python",
297+
"pygments_lexer": "ipython3",
298+
"version": "3.10.11"
299+
}
300+
},
301+
"nbformat": 4,
302+
"nbformat_minor": 5
303+
}
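
The notebook's last cells compare reconstruction-error distributions for MNIST and TF Flowers. The diff does not show how `AnomalyDetector.calculate_error()` computes that error, so the following is only a minimal sketch of one common approach, assuming the error is a per-image mean squared error reduced over the height, width, and channel axes (mirroring the `axis=(1, 2, 3)` argument above). The names `per_sample_mse` and `flag_anomalies` are hypothetical and are not part of the `autoencoder` package.

```python
import numpy as np


def per_sample_mse(originals, reconstructions, axis=(1, 2, 3)):
    """Return one mean squared reconstruction error per image.

    Both inputs are expected to have shape (batch, height, width, channels);
    reducing over axes 1-3 leaves a vector of length `batch`.
    """
    originals = np.asarray(originals, dtype=np.float32)
    reconstructions = np.asarray(reconstructions, dtype=np.float32)
    return np.mean((originals - reconstructions) ** 2, axis=axis)


def flag_anomalies(errors, reference_errors, percentile=99.0):
    """Flag errors above a percentile of the reference (normal) errors.

    The percentile threshold is an illustrative choice, not something
    taken from the notebook.
    """
    threshold = np.percentile(reference_errors, percentile)
    return np.asarray(errors) > threshold, threshold
```

Under these assumptions, `per_sample_mse(x, autoencoder.predict(x))` would give the values that a histogram like the one in the final cell summarizes, and `flag_anomalies` shows one way to turn the two distributions into a decision rule.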
