Skip to content

Commit 58c5cf6

Browse files
Python mnist examples fix return value
1 parent a4e1fcc commit 58c5cf6

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

py/htm/examples/mnist.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import random
2020
import numpy as np
2121
import os
22+
import sys
2223

2324
from htm.bindings.algorithms import SpatialPooler, Classifier
2425
from htm.bindings.sdr import SDR, Metrics
@@ -73,7 +74,7 @@ def load_images(file_name):
7374

7475
return train_labels, train_images, test_labels, test_images
7576

76-
# these parameters can be improved using parameter optimization,
77+
# These parameters can be improved using parameter optimization,
7778
# see py/htm/optimization/ae.py
7879
# For more explanation of relations between the parameters, see
7980
# src/examples/mnist/MNIST_CPP.cpp
@@ -114,7 +115,7 @@ def main(parameters=default_parameters, argv=None, verbose=True):
114115
potentialPct = parameters['potentialPct'],
115116
globalInhibition = True,
116117
localAreaDensity = parameters['localAreaDensity'],
117-
stimulusThreshold = int(round(parameters['stimulusThreshold'])), #param is requested to be an integer, but param optimization might find fractional value, so round it
118+
stimulusThreshold = int(round(parameters['stimulusThreshold'])),
118119
synPermInactiveDec = parameters['synPermInactiveDec'],
119120
synPermActiveInc = parameters['synPermActiveInc'],
120121
synPermConnected = parameters['synPermConnected'],
@@ -145,13 +146,11 @@ def main(parameters=default_parameters, argv=None, verbose=True):
145146
sp.compute( enc, False, columns )
146147
if lbl == np.argmax( sdrc.infer( columns ) ):
147148
score += 1
148-
149149
score = score / len(test_data)
150150

151-
assert score >= 0.951, "MNIST: score should be better than 95.1%"
152151
print('Score:', 100 * score, '%')
153-
return score
152+
return score < 0.95
154153

155154

156155
if __name__ == '__main__':
157-
main()
156+
sys.exit( main() )

0 commit comments

Comments
 (0)