Skip to content

Commit 44b2c4f

Browse files
Merge pull request #555 from htm-community/mnist_baseline
MNIST baseline model: no SP, directly raw images to classifier, 90.5%
2 parents e35be37 + e8a47d1 commit 44b2c4f

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

src/examples/mnist/MNIST_SP.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ class MNIST {
5656
* Score: 95.35% (465 / 10000 wrong) : 28x28x16 : 2 : 125 : : smaller boosting (2.0)
5757
* -- this will be my working model, reasonable performance/speed ratio
5858
*
59+
* Baseline:
60+
* Score: 90.52% (948 / 10000 wrong). : SP disabled : 1 : 0.489 : 01a6c90297 : baseline with only classifier on raw images, on SP
61+
*
5962
*/
6063

6164
private:
@@ -67,7 +70,7 @@ class MNIST {
6770

6871
public:
6972
UInt verbosity = 1;
70-
const UInt train_dataset_iterations = 2u; //epochs somewhat help, at linear time
73+
const UInt train_dataset_iterations = 1u; //epochs somewhat help, at linear time
7174

7275

7376
void setup() {
@@ -104,7 +107,13 @@ void setup() {
104107
mnist::binarize_dataset(dataset);
105108
}
106109

107-
void train() {
110+
111+
/**
112+
* train the SP on the training set.
113+
* @param skipSP bool (default false) if set, output directly the input to the classifier.
114+
* This is used for a baseline benchmark (Classifier directly learns on input images)
115+
*/
116+
void train(const bool skipSP=false) {
108117
// Train
109118

110119
if(verbosity)
@@ -133,8 +142,9 @@ void train() {
133142

134143
// Compute & Train
135144
input.setDense( image );
136-
sp.compute(input, true, columns);
137-
clsr.learn( columns, {label} );
145+
if(not skipSP)
146+
sp.compute(input, true, columns);
147+
clsr.learn( skipSP ? input : columns, {label} );
138148
if( verbosity && (++i % 1000 == 0) ) cout << "." << flush;
139149
}
140150
if( verbosity ) cout << endl;
@@ -154,7 +164,7 @@ void train() {
154164
dump.close();
155165
}
156166

157-
void test() {
167+
void test(const bool skipSP=false) {
158168
// Test
159169
Real score = 0;
160170
UInt n_samples = 0;
@@ -167,9 +177,11 @@ void test() {
167177

168178
// Compute
169179
input.setDense( image );
170-
sp.compute(input, false, columns);
180+
if(not skipSP)
181+
sp.compute(input, false, columns);
182+
171183
// Check results
172-
if( argmax( clsr.infer( columns ) ) == label)
184+
if( argmax( clsr.infer( skipSP ? input : columns ) ) == label)
173185
score += 1;
174186
n_samples += 1;
175187
if( verbosity && i % 1000 == 0 ) cout << "." << flush;
@@ -185,6 +197,11 @@ void test() {
185197
int main(int argc, char **argv) {
186198
MNIST m;
187199
m.setup();
200+
cout << "===========BASELINE: no SP====================" << endl;
201+
m.train(true); //skip SP learning
202+
m.test(true);
203+
cout << "===========Spatial Pooler=====================" << endl;
204+
m.setup();
188205
m.train();
189206
m.test();
190207

src/htm/algorithms/SpatialPooler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ void SpatialPooler::initialize(
425425
spVerbosity_ = spVerbosity;
426426
wrapAround_ = wrapAround;
427427
updatePeriod_ = 50u;
428-
initConnectedPct_ = 0.5f;
428+
initConnectedPct_ = 0.5f; //FIXME make SP's param, and much lower 0.01 https://discourse.numenta.org/t/spatial-pooler-implementation-for-mnist-dataset/2317/25?u=breznak
429429
iterationNum_ = 0u;
430430
iterationLearnNum_ = 0u;
431431

0 commit comments

Comments
 (0)