Skip to content

Commit ac49318

Browse files
committed
Merge pull request #409 from addos/fix_layer_test
Removing malformed test from LayerTest
2 parents e17b65e + ad4353c commit ac49318

File tree

1 file changed

+97
-55
lines changed

1 file changed

+97
-55
lines changed

src/test/java/org/numenta/nupic/network/LayerTest.java

Lines changed: 97 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,22 @@
2121
*/
2222
package org.numenta.nupic.network;
2323

24-
import ch.qos.logback.classic.Level;
25-
import ch.qos.logback.classic.Logger;
26-
import ch.qos.logback.classic.LoggerContext;
27-
import ch.qos.logback.classic.turbo.TurboFilter;
28-
import ch.qos.logback.core.spi.FilterReply;
29-
import ch.qos.logback.core.util.StatusPrinter;
24+
import static org.junit.Assert.assertEquals;
25+
import static org.junit.Assert.assertFalse;
26+
import static org.junit.Assert.assertNotNull;
27+
import static org.junit.Assert.assertTrue;
28+
import static org.junit.Assert.fail;
29+
import static org.numenta.nupic.algorithms.Anomaly.KEY_MODE;
30+
import static org.numenta.nupic.algorithms.Anomaly.KEY_USE_MOVING_AVG;
31+
import static org.numenta.nupic.algorithms.Anomaly.KEY_WINDOW_SIZE;
32+
33+
import java.io.File;
34+
import java.util.ArrayList;
35+
import java.util.Arrays;
36+
import java.util.HashMap;
37+
import java.util.List;
38+
import java.util.Map;
39+
3040
import org.junit.Test;
3141
import org.numenta.nupic.Parameters;
3242
import org.numenta.nupic.Parameters.KEY;
@@ -38,6 +48,7 @@
3848
import org.numenta.nupic.algorithms.TemporalMemory;
3949
import org.numenta.nupic.datagen.ResourceLocator;
4050
import org.numenta.nupic.encoders.MultiEncoder;
51+
import org.numenta.nupic.network.Layer.FunctionFactory;
4152
import org.numenta.nupic.network.sensor.FileSensor;
4253
import org.numenta.nupic.network.sensor.HTMSensor;
4354
import org.numenta.nupic.network.sensor.ObservableSensor;
@@ -49,26 +60,18 @@
4960
import org.numenta.nupic.util.MersenneTwister;
5061
import org.slf4j.LoggerFactory;
5162
import org.slf4j.Marker;
63+
64+
import ch.qos.logback.classic.Level;
65+
import ch.qos.logback.classic.Logger;
66+
import ch.qos.logback.classic.LoggerContext;
67+
import ch.qos.logback.classic.turbo.TurboFilter;
68+
import ch.qos.logback.core.spi.FilterReply;
69+
import ch.qos.logback.core.util.StatusPrinter;
5270
import rx.Observable;
5371
import rx.Observer;
5472
import rx.Subscriber;
5573
import rx.functions.Func1;
56-
57-
import java.io.File;
58-
import java.util.ArrayList;
59-
import java.util.Arrays;
60-
import java.util.HashMap;
61-
import java.util.List;
62-
import java.util.Map;
63-
64-
import static org.junit.Assert.assertEquals;
65-
import static org.junit.Assert.assertFalse;
66-
import static org.junit.Assert.assertNotNull;
67-
import static org.junit.Assert.assertTrue;
68-
import static org.junit.Assert.fail;
69-
import static org.numenta.nupic.algorithms.Anomaly.KEY_MODE;
70-
import static org.numenta.nupic.algorithms.Anomaly.KEY_USE_MOVING_AVG;
71-
import static org.numenta.nupic.algorithms.Anomaly.KEY_WINDOW_SIZE;
74+
import rx.subjects.PublishSubject;
7275

7376
/**
7477
* Tests the "heart and soul" of the Network API
@@ -1417,39 +1420,6 @@ public void testInferInputDimensions() {
14171420
assertTrue(Arrays.equals(new int[] { 1, 450 }, dims));
14181421
}
14191422

1420-
String filterMessage = null;
1421-
@Test
1422-
public void testExplicitCloseFailure() {
1423-
Parameters p = NetworkTestHarness.getParameters();
1424-
p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
1425-
p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42));
1426-
1427-
Network network = Network.create("test network", p)
1428-
.add(Network.createRegion("r1")
1429-
.add(Network.createLayer("2", p)
1430-
.add(Anomaly.create())
1431-
.add(new SpatialPooler())
1432-
.close()));
1433-
1434-
// Set up a log filter to grab the next message.
1435-
LoggerContext lc = (LoggerContext) LoggerFactory.getILoggerFactory();
1436-
StatusPrinter.print(lc);
1437-
lc.addTurboFilter(new TurboFilter() {
1438-
@Override
1439-
public FilterReply decide(Marker arg0, Logger arg1, Level arg2, String arg3, Object[] arg4, Throwable arg5) {
1440-
filterMessage = arg3;
1441-
return FilterReply.ACCEPT;
1442-
}
1443-
});
1444-
1445-
network.lookup("r1").lookup("2").close();
1446-
1447-
// Test that the close() method exited after logging the correct message
1448-
assertEquals("Close called on Layer r1:2 which is already closed.", filterMessage);
1449-
// Make sure not to slow the entire test phase down by removing the filter
1450-
lc.resetTurboFilterList();
1451-
}
1452-
14531423
@Test(expected = IllegalStateException.class)
14541424
public void isClosedAddSensorTest() {
14551425
Parameters p = NetworkTestHarness.getParameters();
@@ -1490,4 +1460,76 @@ public void isClosedAddSpatialPoolerTest() {
14901460
l.add(new SpatialPooler());
14911461
}
14921462

1463+
@Test
1464+
public void testProperConstructionUsingNonFluentConstructor() {
1465+
try {
1466+
new Layer<>(null, null, null, null, null, null);
1467+
fail();
1468+
}catch(Exception e) {
1469+
assertEquals(IllegalArgumentException.class, e.getClass());
1470+
assertEquals("No parameters specified.", e.getMessage());
1471+
}
1472+
1473+
Parameters p = NetworkTestHarness.getParameters();
1474+
p.setParameterByKey(KEY.FIELD_ENCODING_MAP, null);
1475+
try {
1476+
new Layer<>(p, MultiEncoder.builder().build(), null, null, null, null);
1477+
fail();
1478+
}catch(Exception e) {
1479+
assertEquals(IllegalArgumentException.class, e.getClass());
1480+
assertEquals("The passed in Parameters must contain a field encoding map specified by " +
1481+
"org.numenta.nupic.Parameters.KEY.FIELD_ENCODING_MAP", e.getMessage());
1482+
}
1483+
}
1484+
1485+
@Test
1486+
public void testNullSubscriber() {
1487+
Parameters p = NetworkTestHarness.getParameters();
1488+
p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
1489+
p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42));
1490+
1491+
Layer<?> l = Network.createLayer("l", p);
1492+
1493+
try {
1494+
l.subscribe(null);
1495+
fail();
1496+
}catch(Exception e) {
1497+
assertEquals(IllegalArgumentException.class, e.getClass());
1498+
assertEquals("Subscriber cannot be null.", e.getMessage());
1499+
}
1500+
}
1501+
1502+
@SuppressWarnings({ "rawtypes", "unchecked" })
1503+
@Test
1504+
public void testStringToInferenceTransformer() {
1505+
Parameters p = NetworkTestHarness.getParameters();
1506+
p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
1507+
p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42));
1508+
1509+
Layer<?> l = Network.createLayer("l", p);
1510+
FunctionFactory ff = l.new FunctionFactory();
1511+
PublishSubject publisher = PublishSubject.create();
1512+
Observable obs = ff.createEncoderFunc(publisher);
1513+
1514+
String[] sa = { "42" };
1515+
1516+
obs.subscribe(new Observer() {
1517+
@Override public void onCompleted() { }
1518+
@Override public void onError(Throwable arg0) { }
1519+
@Override public void onNext(Object arg0) {
1520+
// System.out.println("here");
1521+
}
1522+
1523+
});
1524+
1525+
assertEquals(0, ff.inference.getRecordNum());
1526+
1527+
publisher.onNext(sa);
1528+
1529+
assertEquals("[42]", (Arrays.toString((int[])ff.inference.getLayerInput())));
1530+
assertEquals(-1, ff.inference.getRecordNum()); // Record number gets set by the Layer which hasn't
1531+
// Received a record yet.
1532+
assertEquals("[42]", (Arrays.toString((int[])ff.inference.getSDR())));
1533+
}
1534+
14931535
}

0 commit comments

Comments
 (0)