Skip to content

Commit bec5f6f

Browse files
authored
Merge pull request #511 from Hopding/napi-sdrclassifier-integration
NAPI SDRClassifier Integration
2 parents a58515a + a3c62b7 commit bec5f6f

File tree

12 files changed

+276
-48
lines changed

12 files changed

+276
-48
lines changed

src/main/java/org/numenta/nupic/Parameters.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,19 @@
2323
package org.numenta.nupic;
2424

2525
import java.io.IOException;
26-
import java.util.Arrays;
2726
import java.util.Collections;
28-
import java.util.EnumMap;
2927
import java.util.HashMap;
30-
import java.util.List;
3128
import java.util.Map;
29+
import java.util.List;
3230
import java.util.Random;
3331
import java.util.Set;
32+
import java.util.EnumMap;
33+
import java.util.Arrays;
3434

3535
import org.numenta.nupic.algorithms.SpatialPooler;
3636
import org.numenta.nupic.algorithms.TemporalMemory;
3737
import org.numenta.nupic.model.Cell;
38+
import org.numenta.nupic.model.Segment;
3839
import org.numenta.nupic.model.Column;
3940
import org.numenta.nupic.model.ComputeCycle;
4041
import org.numenta.nupic.model.DistalDendrite;
@@ -417,8 +418,10 @@ public static enum KEY {
417418

418419
// Network Layer indicator for auto classifier generation
419420
AUTO_CLASSIFY("hasClassifiers", Boolean.class),
420-
421-
421+
422+
/** Maps encoder input field name to type of classifier to be used for them */
423+
INFERRED_FIELDS("inferredFields", Map.class), // Map<String, Classifier.class>
424+
422425
// How many bits to use if encoding the respective date fields.
423426
// e.g. Tuple(bits to use:int, radius:double)
424427
DATEFIELD_SEASON("season", Tuple.class),

src/main/java/org/numenta/nupic/algorithms/CLAClassifier.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
* @author David Ray
6565
* @see BitHistory
6666
*/
67-
public class CLAClassifier implements Persistable {
67+
public class CLAClassifier implements Persistable, Classifier {
6868
private static final long serialVersionUID = 1L;
6969

7070
int verbosity = 0;
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package org.numenta.nupic.algorithms;
2+
3+
import java.util.Map;
4+
5+
/**
6+
* Classifier is an interface for Classifier types used to predict future inputs
7+
* to the system, such as {@link CLAClassifier} or {@link SDRClassifier}.
8+
*/
9+
public interface Classifier {
10+
public <T> Classification<T> compute(int recordNum,
11+
Map<String, Object> classification,
12+
int[] patternNZ,
13+
boolean learn,
14+
boolean infer);
15+
}

src/main/java/org/numenta/nupic/algorithms/SDRClassifier.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
* @author David Ray
9797
* @author Andrew Dillon
9898
*/
99-
public class SDRClassifier implements Persistable {
99+
public class SDRClassifier implements Persistable, Classifier {
100100
private static final long serialVersionUID = 1L;
101101

102102
int verbosity = 0;

src/main/java/org/numenta/nupic/network/Layer.java

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
import org.numenta.nupic.FieldMetaType;
3737
import org.numenta.nupic.Parameters;
3838
import org.numenta.nupic.Parameters.KEY;
39-
import org.numenta.nupic.algorithms.Anomaly;
40-
import org.numenta.nupic.algorithms.CLAClassifier;
4139
import org.numenta.nupic.algorithms.Classification;
42-
import org.numenta.nupic.algorithms.SpatialPooler;
4340
import org.numenta.nupic.algorithms.TemporalMemory;
41+
import org.numenta.nupic.algorithms.SpatialPooler;
42+
import org.numenta.nupic.algorithms.Anomaly;
43+
import org.numenta.nupic.algorithms.Classifier;
44+
import org.numenta.nupic.algorithms.SDRClassifier;
45+
import org.numenta.nupic.algorithms.CLAClassifier;
4446
import org.numenta.nupic.encoders.DateEncoder;
4547
import org.numenta.nupic.encoders.Encoder;
4648
import org.numenta.nupic.encoders.EncoderTuple;
@@ -231,7 +233,7 @@ public class Layer<T> implements Persistable {
231233
private boolean hasGenericProcess;
232234

233235
/**
234-
* List of {@link Encoders} used when storing bucket information see
236+
* List of {@link Encoder}s used when storing bucket information see
235237
* {@link #doEncoderBucketMapping(Inference, Map)}
236238
*/
237239
private List<EncoderTuple> encoderTuples;
@@ -399,7 +401,7 @@ public Layer(Parameters params, MultiEncoder e, SpatialPooler sp, TemporalMemory
399401
(encoder == null ? "" : "MultiEncoder,"),
400402
(spatialPooler == null ? "" : "SpatialPooler,"),
401403
(temporalMemory == null ? "" : "TemporalMemory,"),
402-
(autoCreateClassifiers == null ? "" : "Auto creating CLAClassifiers for each input field."),
404+
(autoCreateClassifiers == null ? "" : "Auto creating Classifiers for each input field."),
403405
(anomalyComputer == null ? "" : "Anomaly"));
404406
}
405407
}
@@ -1048,7 +1050,7 @@ public void start() {
10481050
/**
10491051
* Restarts this {@code Layer}
10501052
*
1051-
* {@link #restart()} is to be called after a call to {@link #halt()}, to begin
1053+
* {@link #restart} is to be called after a call to {@link #halt()}, to begin
10521054
* processing again. The {@link Network} will continue from where it previously
10531055
* left off after the last call to halt().
10541056
*
@@ -1180,7 +1182,7 @@ public Set<Cell> getPredictiveCells() {
11801182
}
11811183

11821184
/**
1183-
* Returns the previous predictive {@link Cells}
1185+
* Returns the previous predictive {@link Cell}s
11841186
*
11851187
* @return the binary vector representing the current prediction.
11861188
*/
@@ -1472,7 +1474,7 @@ void notifyError(Exception e) {
14721474
* </p>
14731475
* <p>
14741476
* If any algorithms are repeated then {@link Inference}s will
1475-
* <em><b>NOT</b></em> be shared between layers. {@link Regions}
1477+
* <em><b>NOT</b></em> be shared between layers. {@link Region}s
14761478
* <em><b>NEVER</b></em> share {@link Inference}s
14771479
* </p>
14781480
*
@@ -1657,7 +1659,7 @@ private Observable<ManualInput> resolveObservableSequence(T t) {
16571659

16581660
/**
16591661
* Executes the check point logic, handles the return of the serialized byte array
1660-
* by delegating the call to {@link rx.Observer#onNext(byte[])} of all the currently queued
1662+
* by delegating the call to {@link rx.Observer#onNext}(byte[]) of all the currently queued
16611663
* Observers; then clears the list of Observers.
16621664
*/
16631665
private void doCheckPoint() {
@@ -1712,7 +1714,15 @@ private void doEncoderBucketMapping(Inference inference, Map<String, Object> enc
17121714
int[] tempArray = new int[e.getWidth()];
17131715
System.arraycopy(encoding, offset, tempArray, 0, tempArray.length);
17141716

1715-
inference.getClassifierInput().put(name, new NamedTuple(new String[] { "name", "inputValue", "bucketIdx", "encoding" }, name, o, bucketIdx, tempArray));
1717+
inference.getClassifierInput().put(
1718+
name,
1719+
new NamedTuple(
1720+
new String[] { "name", "inputValue", "bucketIdx", "encoding" },
1721+
name,
1722+
o,
1723+
bucketIdx,
1724+
tempArray
1725+
));
17161726
}
17171727
}
17181728

@@ -1798,9 +1808,9 @@ private Observable<ManualInput> fillInOrderedSequence(Observable<ManualInput> o)
17981808

17991809
/**
18001810
* Called internally to create a subscription on behalf of the specified
1801-
* {@link LayerObserver}
1811+
* Layer {@link Observer}
18021812
*
1803-
* @param sub the LayerObserver (subscriber).
1813+
* @param sub the Layer Observer (subscriber).
18041814
* @return
18051815
*/
18061816
private Subscription createSubscription(final Observer<Inference> sub) {
@@ -1908,13 +1918,38 @@ private void clearSubscriberObserverLists() {
19081918
* @param encoder
19091919
* @return
19101920
*/
1921+
@SuppressWarnings("unchecked")
19111922
NamedTuple makeClassifiers(MultiEncoder encoder) {
1923+
Map<String, Class<? extends Classifier>> inferredFields = (Map<String, Class<? extends Classifier>>) params.get(KEY.INFERRED_FIELDS);
1924+
if(inferredFields == null || inferredFields.entrySet().size() == 0) {
1925+
throw new IllegalStateException(
1926+
"KEY.AUTO_CLASSIFY has been set to \"true\", but KEY.INFERRED_FIELDS is null or\n\t" +
1927+
"empty. Must specify desired Classifier for at least one input field in\n\t" +
1928+
"KEY.INFERRED_FIELDS or set KEY.AUTO_CLASSIFY to \"false\" (which is its default\n\t" +
1929+
"value in Parameters)."
1930+
);
1931+
}
19121932
String[] names = new String[encoder.getEncoders(encoder).size()];
1913-
CLAClassifier[] ca = new CLAClassifier[names.length];
1933+
Classifier[] ca = new Classifier[names.length];
19141934
int i = 0;
19151935
for(EncoderTuple et : encoder.getEncoders(encoder)) {
19161936
names[i] = et.getName();
1917-
ca[i] = new CLAClassifier();
1937+
Object fieldClassifier = inferredFields.get(et.getName());
1938+
if(fieldClassifier == CLAClassifier.class) {
1939+
LOGGER.info("Classifying \"" + et.getName() + "\" input field with CLAClassifier");
1940+
ca[i] = new CLAClassifier();
1941+
} else if(fieldClassifier == SDRClassifier.class) {
1942+
LOGGER.info("Classifying \"" + et.getName() + "\" input field with SDRClassifier");
1943+
ca[i] = new SDRClassifier();
1944+
} else if(fieldClassifier != null) {
1945+
throw new IllegalStateException(
1946+
"Invalid Classifier class token, \"" + fieldClassifier + "\",\n\t" +
1947+
"specified for, \"" + et.getName() + "\", input field.\n\t" +
1948+
"Valid class tokens are CLAClassifier.class and SDRClassifier.class"
1949+
);
1950+
} else { // fieldClassifier is null
1951+
LOGGER.info("Not classifying \"" + et.getName() + "\" input field");
1952+
}
19181953
i++;
19191954
}
19201955
return new NamedTuple(names, (Object[])ca);
@@ -2014,8 +2049,7 @@ public void run() {
20142049
* that stores the state of this {@code Network} while keeping the Network up and running.
20152050
* The Network will be stored at the pre-configured location (in binary form only, not JSON).
20162051
*
2017-
* @param network the {@link Network} to check point.
2018-
* @return the {@link CheckPointOp} operator
2052+
* @return the {@link CheckPointOp} operator
20192053
*/
20202054
@SuppressWarnings("unchecked")
20212055
CheckPointOp<byte[]> getCheckPointOperator() {
@@ -2328,10 +2362,13 @@ public ManualInput call(ManualInput t1) {
23282362
bucketIdx = inputs.get("bucketIdx");
23292363
actValue = inputs.get("inputValue");
23302364

2331-
CLAClassifier c = (CLAClassifier)t1.getClassifiers().get(key);
2332-
Classification<Object> result = c.compute(recordNum, inputMap, t1.getSDR(), isLearn, true);
2365+
Classifier c = (Classifier)t1.getClassifiers().get(key);
23332366

2334-
t1.recordNum(recordNum).storeClassification((String)inputs.get("name"), result);
2367+
// c will be null if no classifier was specified for this field in KEY.INFERRED_FIELDS map
2368+
if(c != null) {
2369+
Classification<Object> result = c.compute(recordNum, inputMap, t1.getSDR(), isLearn, true);
2370+
t1.recordNum(recordNum).storeClassification((String)inputs.get("name"), result);
2371+
}
23352372
}
23362373

23372374
return t1;

src/main/java/org/numenta/nupic/network/ManualInput.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import java.util.Map;
2828
import java.util.Set;
2929

30-
import org.numenta.nupic.algorithms.CLAClassifier;
30+
import org.numenta.nupic.algorithms.Classifier;
3131
import org.numenta.nupic.algorithms.Classification;
3232
import org.numenta.nupic.algorithms.SpatialPooler;
3333
import org.numenta.nupic.algorithms.TemporalMemory;
@@ -191,7 +191,9 @@ public ManualInput customObject(Object o) {
191191

192192
/**
193193
* <p>
194-
* Returns the {@link Map} used as input into the {@link CLAClassifier}
194+
* Returns the {@link Map} used as input into the field's {@link Classifier}
195+
* (it is only actually used as input if a Classifier type has specified for
196+
* the field).
195197
*
196198
* This mapping contains the name of the field being classified mapped
197199
* to a {@link NamedTuple} containing:
@@ -237,7 +239,7 @@ public ManualInput classifiers(NamedTuple tuple) {
237239

238240
/**
239241
* Returns a {@link NamedTuple} keyed to the input field
240-
* names, whose values are the {@link CLAClassifier} used
242+
* names, whose values are the {@link Classifier} used
241243
* to track the classification of a particular field
242244
*/
243245
@Override
@@ -341,10 +343,12 @@ ManualInput copy() {
341343
* Returns the most recent {@link Classification}
342344
*
343345
* @param fieldName
344-
* @return
346+
* @return the most recent {@link Classification}, or null if none exists.
345347
*/
346348
@Override
347349
public Classification<Object> getClassification(String fieldName) {
350+
if(classification == null)
351+
return null;
348352
return classification.get(fieldName);
349353
}
350354

0 commit comments

Comments
 (0)