Skip to content

Commit 5e2b195

Browse files
authored
Merge pull request #499 from cogmission/zero_overlap_tests
Add 4 SP Zero Overlap Tests & Parameter KEY docs
2 parents 79a96db + f0a4b6e commit 5e2b195

File tree

5 files changed

+265
-23
lines changed

5 files changed

+265
-23
lines changed

src/main/java/org/numenta/nupic/Parameters.java

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ public class Parameters implements Persistable {
107107
defaultSpatialParams.put(KEY.LOCAL_AREA_DENSITY, -1.0);
108108
defaultSpatialParams.put(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 10.0);
109109
defaultSpatialParams.put(KEY.STIMULUS_THRESHOLD, 0.0);
110-
defaultSpatialParams.put(KEY.SYN_PERM_INACTIVE_DEC, 0.008);//0.01
111-
defaultSpatialParams.put(KEY.SYN_PERM_ACTIVE_INC, 0.05);//0.1
110+
defaultSpatialParams.put(KEY.SYN_PERM_INACTIVE_DEC, 0.008);
111+
defaultSpatialParams.put(KEY.SYN_PERM_ACTIVE_INC, 0.05);
112112
defaultSpatialParams.put(KEY.SYN_PERM_CONNECTED, 0.10);
113113
defaultSpatialParams.put(KEY.SYN_PERM_BELOW_STIMULUS_INC, 0.01);
114114
defaultSpatialParams.put(KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
@@ -223,7 +223,7 @@ public static enum KEY {
223223
* predicted but inactive segments are decremented.
224224
*/
225225
PREDICTED_SEGMENT_DECREMENT("predictedSegmentDecrement", Double.class, 0.0, 9.0),
226-
/** Remove this and add Logging (slf4j) */
226+
/** TODO: Remove this and add Logging (slf4j) */
227227
//TM_VERBOSITY("tmVerbosity", Integer.class, 0, 10),
228228

229229

@@ -234,21 +234,136 @@ public static enum KEY {
234234
* using the Network API (which sets this automatically)
235235
*/
236236
POTENTIAL_RADIUS("potentialRadius", Integer.class),
237+
/**
238+
* The percent of the inputs, within a column's potential radius, that a
239+
* column can be connected to. If set to 1, the column will be connected
240+
* to every input within its potential radius. This parameter is used to
241+
* give each column a unique potential pool when a large potentialRadius
242+
* causes overlap between the columns. At initialization time we choose
243+
* ((2*potentialRadius + 1)^(# inputDimensions) * potentialPct) input bits
244+
* to comprise the column's potential pool.
245+
*/
237246
POTENTIAL_PCT("potentialPct", Double.class), //TODO add range here?
247+
/**
248+
* If true, then during inhibition phase the winning columns are selected
249+
* as the most active columns from the region as a whole. Otherwise, the
250+
* winning columns are selected with respect to their local neighborhoods.
251+
* Using global inhibition boosts performance x60.
252+
*/
238253
GLOBAL_INHIBITION("globalInhibition", Boolean.class),
254+
/**
255+
* The inhibition radius determines the size of a column's local
256+
* neighborhood. A cortical column must overcome the overlap score of
257+
* columns in its neighborhood in order to become active. This radius is
258+
* updated every learning round. It grows and shrinks with the average
259+
* number of connected synapses per column.
260+
*/
239261
INHIBITION_RADIUS("inhibitionRadius", Integer.class, 0, null),
262+
/**
263+
* The desired density of active columns within a local inhibition area
264+
* (the size of which is set by the internally calculated inhibitionRadius,
265+
* which is in turn determined from the average size of the connected
266+
* potential pools of all columns). The inhibition logic will insure that
267+
* at most N columns remain ON within a local inhibition area, where
268+
* N = localAreaDensity * (total number of columns in inhibition area).
269+
*/
240270
LOCAL_AREA_DENSITY("localAreaDensity", Double.class), //TODO add range here?
271+
/**
272+
* An alternate way to control the density of the active columns. If
273+
* numActiveColumnsPerInhArea is specified then localAreaDensity must be
274+
* less than 0, and vice versa. When using numActiveColumnsPerInhArea, the
275+
* inhibition logic will insure that at most 'numActiveColumnsPerInhArea'
276+
* columns remain ON within a local inhibition area (the size of which is
277+
* set by the internally calculated inhibitionRadius, which is in turn
278+
* determined from the average size of the connected receptive fields of all
279+
* columns). When using this method, as columns learn and grow their
280+
* effective receptive fields, the inhibitionRadius will grow, and hence the
281+
* net density of the active columns will *decrease*. This is in contrast to
282+
* the localAreaDensity method, which keeps the density of active columns
283+
* the same regardless of the size of their receptive fields.
284+
*/
241285
NUM_ACTIVE_COLUMNS_PER_INH_AREA("numActiveColumnsPerInhArea", Double.class),//TODO add range here?
286+
/**
287+
* This is a number specifying the minimum number of synapses that must be
288+
* on in order for a columns to turn ON. The purpose of this is to prevent
289+
* noise input from activating columns. Specified as a percent of a fully
290+
* grown synapse.
291+
*/
242292
STIMULUS_THRESHOLD("stimulusThreshold", Double.class), //TODO add range here?
293+
/**
294+
* The amount by which an inactive synapse is decremented in each round.
295+
* Specified as a percent of a fully grown synapse.
296+
*/
243297
SYN_PERM_INACTIVE_DEC("synPermInactiveDec", Double.class, 0.0, 1.0),
298+
/**
299+
* The amount by which an active synapse is incremented in each round.
300+
* Specified as a percent of a fully grown synapse.
301+
*/
244302
SYN_PERM_ACTIVE_INC("synPermActiveInc", Double.class, 0.0, 1.0),
303+
/**
304+
* The default connected threshold. Any synapse whose permanence value is
305+
* above the connected threshold is a "connected synapse", meaning it can
306+
* contribute to the cell's firing.
307+
*/
245308
SYN_PERM_CONNECTED("synPermConnected", Double.class, 0.0, 1.0),
309+
/**
310+
* <b>WARNING:</b> This is a <i><b>derived</b><i> value, and is overwritten
311+
* by the SpatialPooler algorithm's initialization.
312+
*
313+
* The permanence increment amount for columns that have not been
314+
* recently active
315+
*/
246316
SYN_PERM_BELOW_STIMULUS_INC("synPermBelowStimulusInc", Double.class, 0.0, 1.0),
317+
/**
318+
* <b>WARNING:</b> This is a <i><b>derived</b><i> value, and is overwritten
319+
* by the SpatialPooler algorithm's initialization.
320+
*
321+
* Values below this are "clipped" and zero'd out.
322+
*/
247323
SYN_PERM_TRIM_THRESHOLD("synPermTrimThreshold", Double.class, 0.0, 1.0),
324+
/**
325+
* A number between 0 and 1.0, used to set a floor on how often a column
326+
* should have at least stimulusThreshold active inputs. Periodically, each
327+
* column looks at the overlap duty cycle of all other columns within its
328+
* inhibition radius and sets its own internal minimal acceptable duty cycle
329+
* to: minPctDutyCycleBeforeInh * max(other columns' duty cycles). On each
330+
* iteration, any column whose overlap duty cycle falls below this computed
331+
* value will get all of its permanence values boosted up by
332+
* synPermActiveInc. Raising all permanences in response to a sub-par duty
333+
* cycle before inhibition allows a cell to search for new inputs when
334+
* either its previously learned inputs are no longer ever active, or when
335+
* the vast majority of them have been "hijacked" by other columns.
336+
*/
248337
MIN_PCT_OVERLAP_DUTY_CYCLES("minPctOverlapDutyCycles", Double.class),//TODO add range here?
338+
/**
339+
* A number between 0 and 1.0, used to set a floor on how often a column
340+
* should be activate. Periodically, each column looks at the activity duty
341+
* cycle of all other columns within its inhibition radius and sets its own
342+
* internal minimal acceptable duty cycle to: minPctDutyCycleAfterInh *
343+
* max(other columns' duty cycles). On each iteration, any column whose duty
344+
* cycle after inhibition falls below this computed value will get its
345+
* internal boost factor increased.
346+
*/
249347
MIN_PCT_ACTIVE_DUTY_CYCLES("minPctActiveDutyCycles", Double.class),//TODO add range here?
348+
/**
349+
* The period used to calculate duty cycles. Higher values make it take
350+
* longer to respond to changes in boost or synPerConnectedCell. Shorter
351+
* values make it more unstable and likely to oscillate.
352+
*/
250353
DUTY_CYCLE_PERIOD("dutyCyclePeriod", Integer.class),//TODO add range here?
354+
/**
355+
* The maximum overlap boost factor. Each column's overlap gets multiplied
356+
* by a boost factor before it gets considered for inhibition. The actual
357+
* boost factor for a column is number between 1.0 and maxBoost. A boost
358+
* factor of 1.0 is used if the duty cycle is >= minOverlapDutyCycle,
359+
* maxBoost is used if the duty cycle is 0, and any duty cycle in between is
360+
* linearly extrapolated from these 2 endpoints.
361+
*/
251362
MAX_BOOST("maxBoost", Double.class), //TODO add range here?
363+
/**
364+
* Determines if inputs at the beginning and end of an input dimension should
365+
* be considered neighbors when mapping columns to inputs.
366+
*/
252367
WRAP_AROUND("wrapAround", Boolean.class),
253368

254369
///////////// SpatialPooler / Network Parameter(s) /////////////

src/main/java/org/numenta/nupic/algorithms/SpatialPooler.java

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
package org.numenta.nupic.algorithms;
2323

2424
import java.util.Arrays;
25-
import java.util.Comparator;
2625
import java.util.stream.IntStream;
2726

2827
import org.numenta.nupic.model.Column;
@@ -808,28 +807,23 @@ public int[] inhibitColumns(Connections c, double[] overlaps) {
808807
public int[] inhibitColumnsGlobal(Connections c, double[] overlaps, double density) {
809808
int numCols = c.getNumColumns();
810809
int numActive = (int)(density * numCols);
811-
812-
Comparator<Pair<Integer, Double>> comparator =
813-
(p1, p2) -> {
814-
int p1key = p1.getFirst();
815-
int p2key = p2.getFirst();
816-
double p1val = p1.getSecond();
817-
double p2val = p2.getSecond();
818-
if(Math.abs(p2val - p1val) < 0.000000001) {
819-
return Math.abs(p2key - p1key) < 0.000000001 ? 0 : p2key > p1key ? 1 : -1;
820-
} else {
821-
return p2val > p1val ? 1 : -1;
822-
}
823-
};
824-
int[] inhibit = IntStream.range(0,overlaps.length)
810+
811+
int[] sortedWinnerIndices = IntStream.range(0,overlaps.length)
825812
.mapToObj(i-> new Pair<>(i,overlaps[i]))
826-
.sorted(comparator)
813+
.sorted(c.inhibitionComparator)
827814
.mapToInt(Pair<Integer,Double>::getFirst)
828-
.limit(numActive)
829-
.sorted()
830815
.toArray();
831816

832-
return inhibit;
817+
// Enforce the stimulus threshold
818+
double stimulusThreshold = c.getStimulusThreshold();
819+
int start = sortedWinnerIndices.length - numActive;
820+
while(start < sortedWinnerIndices.length) {
821+
int i = sortedWinnerIndices[start];
822+
if(overlaps[i] >= stimulusThreshold) break;
823+
++start;
824+
}
825+
826+
return IntStream.of(sortedWinnerIndices).skip(start).toArray();
833827
}
834828

835829
/**

src/main/java/org/numenta/nupic/model/Connections.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.numenta.nupic.util.Topology;
3232
import org.numenta.nupic.util.UniversalRandom;
3333

34+
import chaschev.lang.Pair;
3435
import gnu.trove.list.array.TIntArrayList;
3536

3637
/**
@@ -231,6 +232,19 @@ public class Connections implements Persistable {
231232
return c1 == c2 ? 0 : c1 > c2 ? 1 : -1;
232233
};
233234

235+
/** Sorting Lambda used for SpatialPooler inhibition */
236+
public Comparator<Pair<Integer, Double>> inhibitionComparator = (Comparator<Pair<Integer, Double>> & Serializable)
237+
(p1, p2) -> {
238+
int p1key = p1.getFirst();
239+
int p2key = p2.getFirst();
240+
double p1val = p1.getSecond();
241+
double p2val = p2.getSecond();
242+
if(Math.abs(p2val - p1val) < 0.000000001) {
243+
return Math.abs(p2key - p1key) < 0.000000001 ? 0 : p2key > p1key ? -1 : 1;
244+
} else {
245+
return p2val > p1val ? -1 : 1;
246+
}
247+
};
234248

235249
////////////////////////////////////////
236250
// Connections Constructor //

src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,123 @@ public int[] inhibitColumns(Connections c, double[] overlaps) {
226226
}
227227
}
228228

229+
/**
230+
* When stimulusThreshold is 0, allow columns without any overlap to become
231+
* active. This test focuses on the global inhibition code path.
232+
*/
233+
@Test
234+
public void testZeroOverlap_NoStimulusThreshold_GlobalInhibition() {
235+
int inputSize = 10;
236+
int nColumns = 20;
237+
parameters = Parameters.getSpatialDefaultParameters();
238+
parameters.set(KEY.INPUT_DIMENSIONS, new int[] { inputSize });
239+
parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { nColumns });
240+
parameters.set(KEY.POTENTIAL_RADIUS, 10);
241+
parameters.set(KEY.GLOBAL_INHIBITION, true);
242+
parameters.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 3.0);
243+
parameters.set(KEY.STIMULUS_THRESHOLD, 0.0);
244+
parameters.set(KEY.RANDOM, new UniversalRandom(42));
245+
parameters.set(KEY.SEED, 42);
246+
247+
SpatialPooler sp = new SpatialPooler();
248+
Connections cn = new Connections();
249+
parameters.apply(cn);
250+
sp.init(cn);
251+
252+
int[] activeArray = new int[nColumns];
253+
sp.compute(cn, new int[inputSize], activeArray, true);
254+
255+
assertEquals(3, ArrayUtils.where(activeArray, ArrayUtils.INT_GREATER_THAN_0).length);
256+
}
257+
258+
/**
259+
* When stimulusThreshold is > 0, don't allow columns without any overlap to
260+
* become active. This test focuses on the global inhibition code path.
261+
*/
262+
@Test
263+
public void testZeroOverlap_StimulusThreshold_GlobalInhibition() {
264+
int inputSize = 10;
265+
int nColumns = 20;
266+
parameters = Parameters.getSpatialDefaultParameters();
267+
parameters.set(KEY.INPUT_DIMENSIONS, new int[] { inputSize });
268+
parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { nColumns });
269+
parameters.set(KEY.POTENTIAL_RADIUS, 10);
270+
parameters.set(KEY.GLOBAL_INHIBITION, true);
271+
parameters.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 3.0);
272+
parameters.set(KEY.STIMULUS_THRESHOLD, 1.0);
273+
parameters.set(KEY.RANDOM, new UniversalRandom(42));
274+
parameters.set(KEY.SEED, 42);
275+
276+
SpatialPooler sp = new SpatialPooler();
277+
Connections cn = new Connections();
278+
parameters.apply(cn);
279+
sp.init(cn);
280+
281+
int[] activeArray = new int[nColumns];
282+
sp.compute(cn, new int[inputSize], activeArray, true);
283+
284+
assertEquals(0, ArrayUtils.where(activeArray, ArrayUtils.INT_GREATER_THAN_0).length);
285+
}
286+
287+
@Test
288+
public void testZeroOverlap_NoStimulusThreshold_LocalInhibition() {
289+
int inputSize = 10;
290+
int nColumns = 20;
291+
parameters = Parameters.getSpatialDefaultParameters();
292+
parameters.set(KEY.INPUT_DIMENSIONS, new int[] { inputSize });
293+
parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { nColumns });
294+
parameters.set(KEY.POTENTIAL_RADIUS, 5);
295+
parameters.set(KEY.GLOBAL_INHIBITION, false);
296+
parameters.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 1.0);
297+
parameters.set(KEY.STIMULUS_THRESHOLD, 0.0);
298+
parameters.set(KEY.RANDOM, new UniversalRandom(42));
299+
parameters.set(KEY.SEED, 42);
300+
301+
SpatialPooler sp = new SpatialPooler();
302+
Connections cn = new Connections();
303+
parameters.apply(cn);
304+
sp.init(cn);
305+
306+
// This exact number of active columns is determined by the inhibition
307+
// radius, which changes based on the random synapses (i.e. weird math).
308+
// Force it to a known number.
309+
cn.setInhibitionRadius(2);
310+
311+
int[] activeArray = new int[nColumns];
312+
sp.compute(cn, new int[inputSize], activeArray, true);
313+
314+
assertEquals(6, ArrayUtils.where(activeArray, ArrayUtils.INT_GREATER_THAN_0).length);
315+
}
316+
317+
/**
318+
* When stimulusThreshold is > 0, don't allow columns without any overlap to
319+
* become active. This test focuses on the local inhibition code path.
320+
*/
321+
@Test
322+
public void testZeroOverlap_StimulusThreshold_LocalInhibition() {
323+
int inputSize = 10;
324+
int nColumns = 20;
325+
parameters = Parameters.getSpatialDefaultParameters();
326+
parameters.set(KEY.INPUT_DIMENSIONS, new int[] { inputSize });
327+
parameters.set(KEY.COLUMN_DIMENSIONS, new int[] { nColumns });
328+
parameters.set(KEY.POTENTIAL_RADIUS, 10);
329+
parameters.set(KEY.GLOBAL_INHIBITION, false);
330+
parameters.set(KEY.NUM_ACTIVE_COLUMNS_PER_INH_AREA, 3.0);
331+
parameters.set(KEY.STIMULUS_THRESHOLD, 1.0);
332+
parameters.set(KEY.RANDOM, new UniversalRandom(42));
333+
parameters.set(KEY.SEED, 42);
334+
335+
SpatialPooler sp = new SpatialPooler();
336+
Connections cn = new Connections();
337+
parameters.apply(cn);
338+
sp.init(cn);
339+
340+
int[] activeArray = new int[nColumns];
341+
sp.compute(cn, new int[inputSize], activeArray, true);
342+
343+
assertEquals(0, ArrayUtils.where(activeArray, ArrayUtils.INT_GREATER_THAN_0).length);
344+
}
345+
229346
@Test
230347
public void testOverlapsOutput() {
231348
parameters = Parameters.getSpatialDefaultParameters();
@@ -1558,6 +1675,7 @@ public void testInhibitColumnsGlobal() {
15581675
double[] overlaps = new double[] { 1, 2, 1, 4, 8, 3, 12, 5, 4, 1 };
15591676
int[] active = sp.inhibitColumnsGlobal(mem, overlaps, density);
15601677
int[] trueActive = new int[] { 4, 6, 7 };
1678+
Arrays.sort(active);
15611679
assertTrue(Arrays.equals(trueActive, active));
15621680

15631681
density = 0.5;

src/test/java/org/numenta/nupic/network/RegionTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ public void testMultiLayerAssemblyNoSensor() {
321321
p.set(KEY.SYN_PERM_INACTIVE_DEC, 0.1);
322322
p.set(KEY.SYN_PERM_ACTIVE_INC, 0.1);
323323
p.set(KEY.SYN_PERM_TRIM_THRESHOLD, 0.05);
324+
p.set(KEY.GLOBAL_INHIBITION, true);
324325
p.set(KEY.SYN_PERM_CONNECTED, 0.4);
325326
p.set(KEY.MAX_BOOST, 10.0);
326327
p.set(KEY.DUTY_CYCLE_PERIOD, 7);
@@ -358,7 +359,7 @@ public void testMultiLayerAssemblyNoSensor() {
358359
}
359360
});
360361

361-
final int NUM_CYCLES = 400;
362+
final int NUM_CYCLES = 500;
362363
final int INPUT_GROUP_COUNT = 7; // Days of Week
363364
Map<String, Object> multiInput = new HashMap<>();
364365
for(int i = 0;i < NUM_CYCLES;i++) {

0 commit comments

Comments
 (0)