|
36 | 36 | import org.numenta.nupic.FieldMetaType; |
37 | 37 | import org.numenta.nupic.Parameters; |
38 | 38 | import org.numenta.nupic.Parameters.KEY; |
39 | | -import org.numenta.nupic.algorithms.Anomaly; |
40 | | -import org.numenta.nupic.algorithms.CLAClassifier; |
41 | 39 | import org.numenta.nupic.algorithms.Classification; |
42 | | -import org.numenta.nupic.algorithms.SpatialPooler; |
43 | 40 | import org.numenta.nupic.algorithms.TemporalMemory; |
| 41 | +import org.numenta.nupic.algorithms.SpatialPooler; |
| 42 | +import org.numenta.nupic.algorithms.Anomaly; |
| 43 | +import org.numenta.nupic.algorithms.Classifier; |
| 44 | +import org.numenta.nupic.algorithms.SDRClassifier; |
| 45 | +import org.numenta.nupic.algorithms.CLAClassifier; |
44 | 46 | import org.numenta.nupic.encoders.DateEncoder; |
45 | 47 | import org.numenta.nupic.encoders.Encoder; |
46 | 48 | import org.numenta.nupic.encoders.EncoderTuple; |
@@ -231,7 +233,7 @@ public class Layer<T> implements Persistable { |
231 | 233 | private boolean hasGenericProcess; |
232 | 234 |
|
233 | 235 | /** |
234 | | - * List of {@link Encoders} used when storing bucket information see |
| 236 | + * List of {@link Encoder}s used when storing bucket information see |
235 | 237 | * {@link #doEncoderBucketMapping(Inference, Map)} |
236 | 238 | */ |
237 | 239 | private List<EncoderTuple> encoderTuples; |
@@ -399,7 +401,7 @@ public Layer(Parameters params, MultiEncoder e, SpatialPooler sp, TemporalMemory |
399 | 401 | (encoder == null ? "" : "MultiEncoder,"), |
400 | 402 | (spatialPooler == null ? "" : "SpatialPooler,"), |
401 | 403 | (temporalMemory == null ? "" : "TemporalMemory,"), |
402 | | - (autoCreateClassifiers == null ? "" : "Auto creating CLAClassifiers for each input field."), |
| 404 | + (autoCreateClassifiers == null ? "" : "Auto creating Classifiers for each input field."), |
403 | 405 | (anomalyComputer == null ? "" : "Anomaly")); |
404 | 406 | } |
405 | 407 | } |
@@ -1048,7 +1050,7 @@ public void start() { |
1048 | 1050 | /** |
1049 | 1051 | * Restarts this {@code Layer} |
1050 | 1052 | * |
1051 | | - * {@link #restart()} is to be called after a call to {@link #halt()}, to begin |
| 1053 | + * {@link #restart} is to be called after a call to {@link #halt()}, to begin |
1052 | 1054 | * processing again. The {@link Network} will continue from where it previously |
1053 | 1055 | * left off after the last call to halt(). |
1054 | 1056 | * |
@@ -1180,7 +1182,7 @@ public Set<Cell> getPredictiveCells() { |
1180 | 1182 | } |
1181 | 1183 |
|
1182 | 1184 | /** |
1183 | | - * Returns the previous predictive {@link Cells} |
| 1185 | + * Returns the previous predictive {@link Cell}s |
1184 | 1186 | * |
1185 | 1187 | * @return the binary vector representing the current prediction. |
1186 | 1188 | */ |
@@ -1472,7 +1474,7 @@ void notifyError(Exception e) { |
1472 | 1474 | * </p> |
1473 | 1475 | * <p> |
1474 | 1476 | * If any algorithms are repeated then {@link Inference}s will |
1475 | | - * <em><b>NOT</b></em> be shared between layers. {@link Regions} |
| 1477 | + * <em><b>NOT</b></em> be shared between layers. {@link Region}s |
1476 | 1478 | * <em><b>NEVER</b></em> share {@link Inference}s |
1477 | 1479 | * </p> |
1478 | 1480 | * |
@@ -1657,7 +1659,7 @@ private Observable<ManualInput> resolveObservableSequence(T t) { |
1657 | 1659 |
|
1658 | 1660 | /** |
1659 | 1661 | * Executes the check point logic, handles the return of the serialized byte array |
1660 | | - * by delegating the call to {@link rx.Observer#onNext(byte[])} of all the currently queued |
| 1662 | + * by delegating the call to {@link rx.Observer#onNext}(byte[]) of all the currently queued |
1661 | 1663 | * Observers; then clears the list of Observers. |
1662 | 1664 | */ |
1663 | 1665 | private void doCheckPoint() { |
@@ -1712,7 +1714,15 @@ private void doEncoderBucketMapping(Inference inference, Map<String, Object> enc |
1712 | 1714 | int[] tempArray = new int[e.getWidth()]; |
1713 | 1715 | System.arraycopy(encoding, offset, tempArray, 0, tempArray.length); |
1714 | 1716 |
|
1715 | | - inference.getClassifierInput().put(name, new NamedTuple(new String[] { "name", "inputValue", "bucketIdx", "encoding" }, name, o, bucketIdx, tempArray)); |
| 1717 | + inference.getClassifierInput().put( |
| 1718 | + name, |
| 1719 | + new NamedTuple( |
| 1720 | + new String[] { "name", "inputValue", "bucketIdx", "encoding" }, |
| 1721 | + name, |
| 1722 | + o, |
| 1723 | + bucketIdx, |
| 1724 | + tempArray |
| 1725 | + )); |
1716 | 1726 | } |
1717 | 1727 | } |
1718 | 1728 |
|
@@ -1798,9 +1808,9 @@ private Observable<ManualInput> fillInOrderedSequence(Observable<ManualInput> o) |
1798 | 1808 |
|
1799 | 1809 | /** |
1800 | 1810 | * Called internally to create a subscription on behalf of the specified |
1801 | | - * {@link LayerObserver} |
| 1811 | + * Layer {@link Observer} |
1802 | 1812 | * |
1803 | | - * @param sub the LayerObserver (subscriber). |
| 1813 | + * @param sub the Layer Observer (subscriber). |
1804 | 1814 | * @return |
1805 | 1815 | */ |
1806 | 1816 | private Subscription createSubscription(final Observer<Inference> sub) { |
@@ -1908,13 +1918,38 @@ private void clearSubscriberObserverLists() { |
1908 | 1918 | * @param encoder |
1909 | 1919 | * @return |
1910 | 1920 | */ |
| 1921 | + @SuppressWarnings("unchecked") |
1911 | 1922 | NamedTuple makeClassifiers(MultiEncoder encoder) { |
| 1923 | + Map<String, Class<? extends Classifier>> inferredFields = (Map<String, Class<? extends Classifier>>) params.get(KEY.INFERRED_FIELDS); |
| 1924 | + if(inferredFields == null || inferredFields.entrySet().size() == 0) { |
| 1925 | + throw new IllegalStateException( |
| 1926 | + "KEY.AUTO_CLASSIFY has been set to \"true\", but KEY.INFERRED_FIELDS is null or\n\t" + |
| 1927 | + "empty. Must specify desired Classifier for at least one input field in\n\t" + |
| 1928 | + "KEY.INFERRED_FIELDS or set KEY.AUTO_CLASSIFY to \"false\" (which is its default\n\t" + |
| 1929 | + "value in Parameters)." |
| 1930 | + ); |
| 1931 | + } |
1912 | 1932 | String[] names = new String[encoder.getEncoders(encoder).size()]; |
1913 | | - CLAClassifier[] ca = new CLAClassifier[names.length]; |
| 1933 | + Classifier[] ca = new Classifier[names.length]; |
1914 | 1934 | int i = 0; |
1915 | 1935 | for(EncoderTuple et : encoder.getEncoders(encoder)) { |
1916 | 1936 | names[i] = et.getName(); |
1917 | | - ca[i] = new CLAClassifier(); |
| 1937 | + Object fieldClassifier = inferredFields.get(et.getName()); |
| 1938 | + if(fieldClassifier == CLAClassifier.class) { |
| 1939 | + LOGGER.info("Classifying \"" + et.getName() + "\" input field with CLAClassifier"); |
| 1940 | + ca[i] = new CLAClassifier(); |
| 1941 | + } else if(fieldClassifier == SDRClassifier.class) { |
| 1942 | + LOGGER.info("Classifying \"" + et.getName() + "\" input field with SDRClassifier"); |
| 1943 | + ca[i] = new SDRClassifier(); |
| 1944 | + } else if(fieldClassifier != null) { |
| 1945 | + throw new IllegalStateException( |
| 1946 | + "Invalid Classifier class token, \"" + fieldClassifier + "\",\n\t" + |
| 1947 | + "specified for, \"" + et.getName() + "\", input field.\n\t" + |
| 1948 | + "Valid class tokens are CLAClassifier.class and SDRClassifier.class" |
| 1949 | + ); |
| 1950 | + } else { // fieldClassifier is null |
| 1951 | + LOGGER.info("Not classifying \"" + et.getName() + "\" input field"); |
| 1952 | + } |
1918 | 1953 | i++; |
1919 | 1954 | } |
1920 | 1955 | return new NamedTuple(names, (Object[])ca); |
@@ -2014,8 +2049,7 @@ public void run() { |
2014 | 2049 | * that stores the state of this {@code Network} while keeping the Network up and running. |
2015 | 2050 | * The Network will be stored at the pre-configured location (in binary form only, not JSON). |
2016 | 2051 | * |
2017 | | - * @param network the {@link Network} to check point. |
2018 | | - * @return the {@link CheckPointOp} operator |
| 2052 | + * @return the {@link CheckPointOp} operator |
2019 | 2053 | */ |
2020 | 2054 | @SuppressWarnings("unchecked") |
2021 | 2055 | CheckPointOp<byte[]> getCheckPointOperator() { |
@@ -2328,10 +2362,13 @@ public ManualInput call(ManualInput t1) { |
2328 | 2362 | bucketIdx = inputs.get("bucketIdx"); |
2329 | 2363 | actValue = inputs.get("inputValue"); |
2330 | 2364 |
|
2331 | | - CLAClassifier c = (CLAClassifier)t1.getClassifiers().get(key); |
2332 | | - Classification<Object> result = c.compute(recordNum, inputMap, t1.getSDR(), isLearn, true); |
| 2365 | + Classifier c = (Classifier)t1.getClassifiers().get(key); |
2333 | 2366 |
|
2334 | | - t1.recordNum(recordNum).storeClassification((String)inputs.get("name"), result); |
| 2367 | + // c will be null if no classifier was specified for this field in KEY.INFERRED_FIELDS map |
| 2368 | + if(c != null) { |
| 2369 | + Classification<Object> result = c.compute(recordNum, inputMap, t1.getSDR(), isLearn, true); |
| 2370 | + t1.recordNum(recordNum).storeClassification((String)inputs.get("name"), result); |
| 2371 | + } |
2335 | 2372 | } |
2336 | 2373 |
|
2337 | 2374 | return t1; |
|
0 commit comments