|
21 | 21 | */ |
22 | 22 | package org.numenta.nupic.network; |
23 | 23 |
|
24 | | -import ch.qos.logback.classic.Level; |
25 | | -import ch.qos.logback.classic.Logger; |
26 | | -import ch.qos.logback.classic.LoggerContext; |
27 | | -import ch.qos.logback.classic.turbo.TurboFilter; |
28 | | -import ch.qos.logback.core.spi.FilterReply; |
29 | | -import ch.qos.logback.core.util.StatusPrinter; |
| 24 | +import static org.junit.Assert.assertEquals; |
| 25 | +import static org.junit.Assert.assertFalse; |
| 26 | +import static org.junit.Assert.assertNotNull; |
| 27 | +import static org.junit.Assert.assertTrue; |
| 28 | +import static org.junit.Assert.fail; |
| 29 | +import static org.numenta.nupic.algorithms.Anomaly.KEY_MODE; |
| 30 | +import static org.numenta.nupic.algorithms.Anomaly.KEY_USE_MOVING_AVG; |
| 31 | +import static org.numenta.nupic.algorithms.Anomaly.KEY_WINDOW_SIZE; |
| 32 | + |
| 33 | +import java.io.File; |
| 34 | +import java.util.ArrayList; |
| 35 | +import java.util.Arrays; |
| 36 | +import java.util.HashMap; |
| 37 | +import java.util.List; |
| 38 | +import java.util.Map; |
| 39 | + |
30 | 40 | import org.junit.Test; |
31 | 41 | import org.numenta.nupic.Parameters; |
32 | 42 | import org.numenta.nupic.Parameters.KEY; |
|
38 | 48 | import org.numenta.nupic.algorithms.TemporalMemory; |
39 | 49 | import org.numenta.nupic.datagen.ResourceLocator; |
40 | 50 | import org.numenta.nupic.encoders.MultiEncoder; |
| 51 | +import org.numenta.nupic.network.Layer.FunctionFactory; |
41 | 52 | import org.numenta.nupic.network.sensor.FileSensor; |
42 | 53 | import org.numenta.nupic.network.sensor.HTMSensor; |
43 | 54 | import org.numenta.nupic.network.sensor.ObservableSensor; |
|
49 | 60 | import org.numenta.nupic.util.MersenneTwister; |
50 | 61 | import org.slf4j.LoggerFactory; |
51 | 62 | import org.slf4j.Marker; |
| 63 | + |
| 64 | +import ch.qos.logback.classic.Level; |
| 65 | +import ch.qos.logback.classic.Logger; |
| 66 | +import ch.qos.logback.classic.LoggerContext; |
| 67 | +import ch.qos.logback.classic.turbo.TurboFilter; |
| 68 | +import ch.qos.logback.core.spi.FilterReply; |
| 69 | +import ch.qos.logback.core.util.StatusPrinter; |
52 | 70 | import rx.Observable; |
53 | 71 | import rx.Observer; |
54 | 72 | import rx.Subscriber; |
55 | 73 | import rx.functions.Func1; |
56 | | - |
57 | | -import java.io.File; |
58 | | -import java.util.ArrayList; |
59 | | -import java.util.Arrays; |
60 | | -import java.util.HashMap; |
61 | | -import java.util.List; |
62 | | -import java.util.Map; |
63 | | - |
64 | | -import static org.junit.Assert.assertEquals; |
65 | | -import static org.junit.Assert.assertFalse; |
66 | | -import static org.junit.Assert.assertNotNull; |
67 | | -import static org.junit.Assert.assertTrue; |
68 | | -import static org.junit.Assert.fail; |
69 | | -import static org.numenta.nupic.algorithms.Anomaly.KEY_MODE; |
70 | | -import static org.numenta.nupic.algorithms.Anomaly.KEY_USE_MOVING_AVG; |
71 | | -import static org.numenta.nupic.algorithms.Anomaly.KEY_WINDOW_SIZE; |
| 74 | +import rx.subjects.PublishSubject; |
72 | 75 |
|
73 | 76 | /** |
74 | 77 | * Tests the "heart and soul" of the Network API |
@@ -1417,39 +1420,6 @@ public void testInferInputDimensions() { |
1417 | 1420 | assertTrue(Arrays.equals(new int[] { 1, 450 }, dims)); |
1418 | 1421 | } |
1419 | 1422 |
|
1420 | | - String filterMessage = null; |
1421 | | - @Test |
1422 | | - public void testExplicitCloseFailure() { |
1423 | | - Parameters p = NetworkTestHarness.getParameters(); |
1424 | | - p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams()); |
1425 | | - p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42)); |
1426 | | - |
1427 | | - Network network = Network.create("test network", p) |
1428 | | - .add(Network.createRegion("r1") |
1429 | | - .add(Network.createLayer("2", p) |
1430 | | - .add(Anomaly.create()) |
1431 | | - .add(new SpatialPooler()) |
1432 | | - .close())); |
1433 | | - |
1434 | | - // Set up a log filter to grab the next message. |
1435 | | - LoggerContext lc = (LoggerContext) LoggerFactory.getILoggerFactory(); |
1436 | | - StatusPrinter.print(lc); |
1437 | | - lc.addTurboFilter(new TurboFilter() { |
1438 | | - @Override |
1439 | | - public FilterReply decide(Marker arg0, Logger arg1, Level arg2, String arg3, Object[] arg4, Throwable arg5) { |
1440 | | - filterMessage = arg3; |
1441 | | - return FilterReply.ACCEPT; |
1442 | | - } |
1443 | | - }); |
1444 | | - |
1445 | | - network.lookup("r1").lookup("2").close(); |
1446 | | - |
1447 | | - // Test that the close() method exited after logging the correct message |
1448 | | - assertEquals("Close called on Layer r1:2 which is already closed.", filterMessage); |
1449 | | - // Make sure not to slow the entire test phase down by removing the filter |
1450 | | - lc.resetTurboFilterList(); |
1451 | | - } |
1452 | | - |
1453 | 1423 | @Test(expected = IllegalStateException.class) |
1454 | 1424 | public void isClosedAddSensorTest() { |
1455 | 1425 | Parameters p = NetworkTestHarness.getParameters(); |
@@ -1490,4 +1460,76 @@ public void isClosedAddSpatialPoolerTest() { |
1490 | 1460 | l.add(new SpatialPooler()); |
1491 | 1461 | } |
1492 | 1462 |
|
| 1463 | + @Test |
| 1464 | + public void testProperConstructionUsingNonFluentConstructor() { |
| 1465 | + try { |
| 1466 | + new Layer<>(null, null, null, null, null, null); |
| 1467 | + fail(); |
| 1468 | + }catch(Exception e) { |
| 1469 | + assertEquals(IllegalArgumentException.class, e.getClass()); |
| 1470 | + assertEquals("No parameters specified.", e.getMessage()); |
| 1471 | + } |
| 1472 | + |
| 1473 | + Parameters p = NetworkTestHarness.getParameters(); |
| 1474 | + p.setParameterByKey(KEY.FIELD_ENCODING_MAP, null); |
| 1475 | + try { |
| 1476 | + new Layer<>(p, MultiEncoder.builder().build(), null, null, null, null); |
| 1477 | + fail(); |
| 1478 | + }catch(Exception e) { |
| 1479 | + assertEquals(IllegalArgumentException.class, e.getClass()); |
| 1480 | + assertEquals("The passed in Parameters must contain a field encoding map specified by " + |
| 1481 | + "org.numenta.nupic.Parameters.KEY.FIELD_ENCODING_MAP", e.getMessage()); |
| 1482 | + } |
| 1483 | + } |
| 1484 | + |
| 1485 | + @Test |
| 1486 | + public void testNullSubscriber() { |
| 1487 | + Parameters p = NetworkTestHarness.getParameters(); |
| 1488 | + p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams()); |
| 1489 | + p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42)); |
| 1490 | + |
| 1491 | + Layer<?> l = Network.createLayer("l", p); |
| 1492 | + |
| 1493 | + try { |
| 1494 | + l.subscribe(null); |
| 1495 | + fail(); |
| 1496 | + }catch(Exception e) { |
| 1497 | + assertEquals(IllegalArgumentException.class, e.getClass()); |
| 1498 | + assertEquals("Subscriber cannot be null.", e.getMessage()); |
| 1499 | + } |
| 1500 | + } |
| 1501 | + |
| 1502 | + @SuppressWarnings({ "rawtypes", "unchecked" }) |
| 1503 | + @Test |
| 1504 | + public void testStringToInferenceTransformer() { |
| 1505 | + Parameters p = NetworkTestHarness.getParameters(); |
| 1506 | + p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams()); |
| 1507 | + p.setParameterByKey(KEY.RANDOM, new MersenneTwister(42)); |
| 1508 | + |
| 1509 | + Layer<?> l = Network.createLayer("l", p); |
| 1510 | + FunctionFactory ff = l.new FunctionFactory(); |
| 1511 | + PublishSubject publisher = PublishSubject.create(); |
| 1512 | + Observable obs = ff.createEncoderFunc(publisher); |
| 1513 | + |
| 1514 | + String[] sa = { "42" }; |
| 1515 | + |
| 1516 | + obs.subscribe(new Observer() { |
| 1517 | + @Override public void onCompleted() { } |
| 1518 | + @Override public void onError(Throwable arg0) { } |
| 1519 | + @Override public void onNext(Object arg0) { |
| 1520 | + // System.out.println("here"); |
| 1521 | + } |
| 1522 | + |
| 1523 | + }); |
| 1524 | + |
| 1525 | + assertEquals(0, ff.inference.getRecordNum()); |
| 1526 | + |
| 1527 | + publisher.onNext(sa); |
| 1528 | + |
| 1529 | + assertEquals("[42]", (Arrays.toString((int[])ff.inference.getLayerInput()))); |
| 1530 | + assertEquals(-1, ff.inference.getRecordNum()); // Record number gets set by the Layer which hasn't |
| 1531 | + // Received a record yet. |
| 1532 | + assertEquals("[42]", (Arrays.toString((int[])ff.inference.getSDR()))); |
| 1533 | + } |
| 1534 | + |
1493 | 1535 | } |
0 commit comments