|
58 | 58 | "x_train = x_train.astype(\"float32\") / 255.\n", |
59 | 59 | "x_test = x_test.astype(\"float32\") / 255.\n", |
60 | 60 | "\n", |
| 61 | + "# add grayscale dimension\n", |
| 62 | + "x_train = tf.expand_dims(x_train, axis=-1)\n", |
| 63 | + "x_test = tf.expand_dims(x_test, axis=-1)\n", |
| 64 | + "\n", |
61 | 65 | "# convert to tf datasets\n", |
62 | 66 | "train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train))\n", |
63 | 67 | "test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test))\n", |
|
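A minimal sketch of what this preprocessing hunk amounts to, assuming the standard keras.datasets MNIST loader; the batch size and shuffle buffer below are illustrative and not taken from the notebook.

# assumes keras.datasets MNIST; batching/shuffling values are illustrative
import tensorflow as tf

(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0        # (60000, 28, 28)
x_test = x_test.astype("float32") / 255.0          # (10000, 28, 28)

# add grayscale channel so samples are (28, 28, 1)
x_train = tf.expand_dims(x_train, axis=-1)          # (60000, 28, 28, 1)
x_test = tf.expand_dims(x_test, axis=-1)            # (10000, 28, 28, 1)

# inputs double as targets for reconstruction training
train_ds = tf.data.Dataset.from_tensor_slices((x_train, x_train)).shuffle(1024).batch(64)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, x_test)).batch(64)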
92 | 96 | "\n", |
93 | 97 | "# seupt config\n", |
94 | 98 | "config = MinNDParams(\n", |
95 | | - " l0={\"input_shape\": (28, 28)},\n", |
| 99 | + " l0={\"input_shape\": (28, 28, 1)},\n", |
96 | 100 | " l2={\"units\": 32 * 1},\n", |
97 | 101 | " l3={\"units\": 28 * 28 * 1},\n", |
98 | | - " l4={\"target_shape\": (28, 28)},\n", |
| 102 | + " l4={\"target_shape\": (28, 28, 1)},\n", |
99 | 103 | ")\n", |
100 | 104 | "\n", |
101 | 105 | "# get ae instance\n", |
|
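For orientation, the updated config appears to describe a dense bottleneck autoencoder over (28, 28, 1) inputs. The sketch below builds roughly equivalent layers directly in Keras rather than through MinNDParams, so the implied Flatten layer, activations, and layer ordering are assumptions, not the project's actual definition.

# hedged sketch of the architecture the MinNDParams config seems to describe;
# the Flatten layer and activation choices are assumptions
from keras import layers, models

equivalent_ae = models.Sequential([
    layers.Input(shape=(28, 28, 1)),                   # l0: input_shape
    layers.Flatten(),                                   # assumed l1: 28*28*1 -> 784
    layers.Dense(32, activation="relu"),                # l2: units = 32 (bottleneck)
    layers.Dense(28 * 28 * 1, activation="sigmoid"),    # l3: units = 784
    layers.Reshape((28, 28, 1)),                        # l4: target_shape
])
equivalent_ae.summary()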
116 | 120 | "from keras.callbacks import EarlyStopping\n", |
117 | 121 | "\n", |
118 | 122 | "# create callback\n", |
119 | | - "early_stop_callback = EarlyStopping(monitor=\"val_loss\", patience=3)\n", |
| 123 | + "early_stop_callback = EarlyStopping(monitor=\"val_loss\", patience=2)\n", |
120 | 124 | "\n", |
121 | 125 | "# compile ae\n", |
122 | | - "autoencoder.compile(optimizer=\"adam\", loss=\"binary_crossentropy\")\n", |
| 126 | + "autoencoder.compile(optimizer=\"adam\", loss=\"mean_squared_error\")\n", |
123 | 127 | "\n", |
124 | 128 | "# begin model fit\n", |
125 | 129 | "autoencoder.fit(\n", |
|
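The fit(...) call is truncated in this hunk, so the following is only a plausible completion: the epochs value and the use of test_ds as validation_data are assumptions, while the callback, optimizer, and loss match the diff above.

# sketch of the training step; epochs and validation_data are assumed,
# since the notebook's fit(...) arguments are not shown in this hunk
from keras.callbacks import EarlyStopping

early_stop_callback = EarlyStopping(monitor="val_loss", patience=2)

autoencoder.compile(optimizer="adam", loss="mean_squared_error")

autoencoder.fit(
    train_ds,
    validation_data=test_ds,
    epochs=20,
    callbacks=[early_stop_callback],
)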
204 | 208 | "from autoencoder.data.anomaly import AnomalyDetector\n", |
205 | 209 | "\n", |
206 | 210 | "# get instance\n", |
207 | | - "mnist_recon_error = AnomalyDetector(autoencoder, test_ds, axis=(1, 2))\n", |
| 211 | + "mnist_recon_error = AnomalyDetector(autoencoder, test_ds, axis=(1, 2, 3))\n", |
208 | 212 | "\n", |
209 | 213 | "# calculate recon error\n", |
210 | 214 | "mnist_recon_error.calculate_error()\n", |
|
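The axis change to (1, 2, 3) reflects the new channel dimension: the per-sample error now reduces over height, width, and channel. A hedged sketch of that computation done directly in TensorFlow, bypassing the project's AnomalyDetector class (its internals are not shown here):

# per-image reconstruction error computed directly; assumes the model
# outputs (N, 28, 28, 1) reconstructions like its inputs
import tensorflow as tf

recon = autoencoder.predict(x_test)                              # (10000, 28, 28, 1)
per_image_mse = tf.reduce_mean(tf.square(x_test - recon), axis=(1, 2, 3))
print(per_image_mse.shape)                                        # (10000,)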