1717
1818import argparse
1919import random
20- import gzip
2120import numpy as np
2221import os
22+ import sys
2323
2424from htm .bindings .algorithms import SpatialPooler , Classifier
2525from htm .bindings .sdr import SDR , Metrics
@@ -36,7 +36,7 @@ def int32(b):
3636 return i
3737
3838 def load_labels (file_name ):
39- with gzip . open (file_name , 'rb' ) as f :
39+ with open (file_name , 'rb' ) as f :
4040 raw = f .read ()
4141 assert (int32 (raw [0 :4 ]) == 2049 ) # Magic number
4242 labels = []
@@ -46,7 +46,7 @@ def load_labels(file_name):
4646 return labels
4747
4848 def load_images (file_name ):
49- with gzip . open (file_name , 'rb' ) as f :
49+ with open (file_name , 'rb' ) as f :
5050 raw = f .read ()
5151 assert (int32 (raw [0 :4 ]) == 2051 ) # Magic number
5252 num_imgs = int32 (raw [4 :8 ])
@@ -67,32 +67,36 @@ def load_images(file_name):
6767 assert (len (raw ) == data_start + img_size * num_imgs ) # All data should be used.
6868 return imgs
6969
70- train_labels = load_labels (os .path .join (path , 'train-labels-idx1-ubyte.gz ' ))
71- train_images = load_images (os .path .join (path , 'train-images-idx3-ubyte.gz ' ))
72- test_labels = load_labels (os .path .join (path , 't10k-labels-idx1-ubyte.gz ' ))
73- test_images = load_images (os .path .join (path , 't10k-images-idx3-ubyte.gz ' ))
70+ train_labels = load_labels (os .path .join (path , 'train-labels-idx1-ubyte' ))
71+ train_images = load_images (os .path .join (path , 'train-images-idx3-ubyte' ))
72+ test_labels = load_labels (os .path .join (path , 't10k-labels-idx1-ubyte' ))
73+ test_images = load_images (os .path .join (path , 't10k-images-idx3-ubyte' ))
7474
7575 return train_labels , train_images , test_labels , test_images
7676
77-
77+ # These parameters can be improved using parameter optimization,
78+ # see py/htm/optimization/ae.py
79+ # For more explanation of relations between the parameters, see
80+ # src/examples/mnist/MNIST_CPP.cpp
7881default_parameters = {
79- 'boostStrength' : 7.80643753517375 ,
80- 'columnDimensions' : (35415 ,1 ),
81- 'dutyCyclePeriod' : 1321 ,
82- 'localAreaDensity' : 0.05361688506086096 ,
83- 'minPctOverlapDutyCycle' : 0.0016316043362658 ,
84- 'potentialPct' : 0.06799785776775163 ,
85- 'stimulusThreshold' : 8 ,
86- 'synPermActiveInc' : 0.01455789388651146 ,
87- 'synPermConnected' : 0.021649964738697944 ,
88- 'synPermInactiveDec' : 0.006442691852205935
82+ 'potentialRadius' : 7 ,
83+ 'boostStrength' : 7.0 ,
84+ 'columnDimensions' : (28 * 28 * 8 , 1 ),
85+ 'dutyCyclePeriod' : 1402 ,
86+ 'localAreaDensity' : 0.1 ,
87+ 'minPctOverlapDutyCycle' : 0.2 ,
88+ 'potentialPct' : 0.1 ,
89+ 'stimulusThreshold' : 6 ,
90+ 'synPermActiveInc' : 0.14 ,
91+ 'synPermConnected' : 0.5 ,
92+ 'synPermInactiveDec' : 0.02
8993}
9094
9195
9296def main (parameters = default_parameters , argv = None , verbose = True ):
9397 parser = argparse .ArgumentParser ()
9498 parser .add_argument ('--data_dir' , type = str ,
95- default = os .path .join ( os .path .dirname (__file__ ), 'MNIST_data ' ))
99+ default = os .path .join ( os .path .dirname (__file__ ), '..' , '..' , '..' , 'build' , 'ThirdParty' , 'mnist_data' , 'mnist-src ' ))
96100 args = parser .parse_args (args = argv )
97101
98102 # Load data.
@@ -107,11 +111,10 @@ def main(parameters=default_parameters, argv=None, verbose=True):
107111 sp = SpatialPooler (
108112 inputDimensions = enc .dimensions ,
109113 columnDimensions = parameters ['columnDimensions' ],
110- potentialRadius = 99999999 ,
114+ potentialRadius = parameters [ 'potentialRadius' ] ,
111115 potentialPct = parameters ['potentialPct' ],
112116 globalInhibition = True ,
113117 localAreaDensity = parameters ['localAreaDensity' ],
114- numActiveColumnsPerInhArea = - 1 ,
115118 stimulusThreshold = int (round (parameters ['stimulusThreshold' ])),
116119 synPermInactiveDec = parameters ['synPermInactiveDec' ],
117120 synPermActiveInc = parameters ['synPermActiveInc' ],
@@ -143,10 +146,11 @@ def main(parameters=default_parameters, argv=None, verbose=True):
143146 sp .compute ( enc , False , columns )
144147 if lbl == np .argmax ( sdrc .infer ( columns ) ):
145148 score += 1
149+ score = score / len (test_data )
146150
147- print ('Score:' , 100 * score / len ( test_data ) , '%' )
148- return score / len ( test_data )
151+ print ('Score:' , 100 * score , '%' )
152+ return score < 0.95
149153
150154
151155if __name__ == '__main__' :
152- main ()
156+ sys . exit ( main () )
0 commit comments