Skip to content

Commit ec0dc3a

Browse files
Adding 3rd dimension to mnist data
1 parent a49334d commit ec0dc3a

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

notebooks/demo/mnist_dataset.ipynb

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@
5858
"x_train = x_train.astype(\"float32\") / 255.\n",
5959
"x_test = x_test.astype(\"float32\") / 255.\n",
6060
"\n",
61+
"# add grayscale dimension\n",
62+
"x_train = tf.expand_dims(x_train, axis=-1)\n",
63+
"x_test = tf.expand_dims(x_test, axis=-1)\n",
64+
"\n",
6165
"# convert to tf datasets\n",
6266
"train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train))\n",
6367
"test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test))\n",
@@ -92,10 +96,10 @@
9296
"\n",
9397
"# seupt config\n",
9498
"config = MinNDParams(\n",
95-
" l0={\"input_shape\": (28, 28)},\n",
99+
" l0={\"input_shape\": (28, 28, 1)},\n",
96100
" l2={\"units\": 32 * 1},\n",
97101
" l3={\"units\": 28 * 28 * 1},\n",
98-
" l4={\"target_shape\": (28, 28)},\n",
102+
" l4={\"target_shape\": (28, 28, 1)},\n",
99103
")\n",
100104
"\n",
101105
"# get ae instance\n",
@@ -116,10 +120,10 @@
116120
"from keras.callbacks import EarlyStopping\n",
117121
"\n",
118122
"# create callback\n",
119-
"early_stop_callback = EarlyStopping(monitor=\"val_loss\", patience=3)\n",
123+
"early_stop_callback = EarlyStopping(monitor=\"val_loss\", patience=2)\n",
120124
"\n",
121125
"# compile ae\n",
122-
"autoencoder.compile(optimizer=\"adam\", loss=\"binary_crossentropy\")\n",
126+
"autoencoder.compile(optimizer=\"adam\", loss=\"mean_squared_error\")\n",
123127
"\n",
124128
"# begin model fit\n",
125129
"autoencoder.fit(\n",
@@ -204,7 +208,7 @@
204208
"from autoencoder.data.anomaly import AnomalyDetector\n",
205209
"\n",
206210
"# get instance\n",
207-
"mnist_recon_error = AnomalyDetector(autoencoder, test_ds, axis=(1, 2))\n",
211+
"mnist_recon_error = AnomalyDetector(autoencoder, test_ds, axis=(1, 2, 3))\n",
208212
"\n",
209213
"# calculate recon error\n",
210214
"mnist_recon_error.calculate_error()\n",

0 commit comments

Comments
 (0)