Skip to content

Commit 77172b8

Browse files
committed
MNIST count and show number of errors.
This removes the dependency on sklearn and also makes the computation more transparent. Additionally, people often report the number of errors on MNIST (e.g., in the dropout paper) rather than accuracy.
1 parent 9053a98 commit 77172b8

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/MNIST/test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import numpy as np
22
from progress_bar import *
33
import theano as _th
4-
from sklearn.metrics import accuracy_score
54

65
def validate(dataset_x, dataset_y, model, epoch, batch_size):
76
progress = make_progressbar('Testing', epoch, len(dataset_x))
87
progress.start()
98

109
mini_batch_input = np.empty(shape=(batch_size, 28*28), dtype=_th.config.floatX)
1110
mini_batch_targets = np.empty(shape=(batch_size, ), dtype=_th.config.floatX)
12-
accuracy = 0
11+
nerrors = 0
1312

1413
for j in range((dataset_x.shape[0] + batch_size - 1) // batch_size):
1514
progress.update(j * batch_size)
@@ -24,7 +23,8 @@ def validate(dataset_x, dataset_y, model, epoch, batch_size):
2423
mini_batch_prediction.resize((dataset_x.shape[0] - j * batch_size, ))
2524
mini_batch_targets.resize((dataset_x.shape[0] - j * batch_size, ))
2625

27-
accuracy = accuracy + accuracy_score(mini_batch_targets, mini_batch_prediction, normalize=False)
26+
nerrors += sum(mini_batch_targets != mini_batch_prediction)
2827

2928
progress.finish()
30-
print("Epoch #" + str(epoch) + ", Classification: " + str(float(accuracy) / dataset_x.shape[0] * 100.0))
29+
accuracy = 1 - float(nerrors)/dataset_x.shape[0]
30+
print("Epoch #{}, Classification accuracy: {:.2%} ({} errors)".format(epoch, accuracy, nerrors))

0 commit comments

Comments
 (0)