Skip to content

Commit 1aca435

Browse files
authored
Merge pull request #747 from htm-community/pickled_network2
Continuation of fixes for Pickled networkAPI
2 parents cb5e0f1 + d83196d commit 1aca435

File tree

17 files changed

+242
-230
lines changed

17 files changed

+242
-230
lines changed

bindings/py/cpp_src/bindings/engine/py_Engine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,8 @@ namespace htm_ext
445445
py_Network.def("save", &htm::Network::save)
446446
.def("load", &htm::Network::load)
447447
.def("saveToFile", &htm::Network::saveToFile, py::arg("file"), py::arg("fmt") = SerializableFormat::BINARY)
448-
.def("loadFromFile", &htm::Network::loadFromFile, py::arg("file"), py::arg("fmt") = SerializableFormat::BINARY);
448+
.def("loadFromFile", &htm::Network::loadFromFile, py::arg("file"), py::arg("fmt") = SerializableFormat::BINARY)
449+
.def("__eq__", &htm::Network::operator==);
449450

450451
py_Network.def(py::pickle(
451452
[](const Network& self) {

bindings/py/cpp_src/plugin/PyBindRegion.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ namespace py = pybind11;
291291
, className_(className)
292292

293293
{
294+
// Make a local copy of the Spec
295+
createSpec(module_.c_str(), nodeSpec_, className_.c_str());
294296

295297
cereal_adapter_load(wrapper);
296298
}

bindings/py/cpp_src/plugin/PyBindRegion.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ namespace htm
131131

132132

133133
private:
134-
std::string module_;
135-
std::string className_;
134+
std::string module_; // Full path to the class.
135+
std::string className_; // Just the name of the class.
136136

137137
pybind11::object node_;
138138

bindings/py/cpp_src/plugin/RegisteredRegionImplPy.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ namespace htm
168168
* when its name is used in a Network::addRegion() call.
169169
*
170170
* @param className -- the name of the Python class that implements the region.
171-
* @param module -- the module (shared library) in which the class resides.
171+
* @param module -- the module (full path and file) in which the class resides.
172172
*/
173173
inline static void registerPyRegion(const std::string& module, const std::string& className) {
174174
std::string nodeType = "py." + className;

bindings/py/packaging/src/htm/bindings/regions/PyRegion.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424
import collections
2525

2626
RealNumpyDType = numpy.float32
27-
from abc import ABCMeta, abstractmethod
28-
29-
class DictReadOnlyWrapper(collections.Mapping):
27+
if sys.version > '3':
28+
from abc import ABC,abstractmethod
29+
# see http://www.programmersought.com/article/7351237937/
30+
else:
31+
from abc import ABCMeta, abstractmethod
32+
class DictReadOnlyWrapper(collections.abc.Mapping):
3033
"""
3134
Provides read-only access to a dict. When dict items are mutable, they can
3235
still be mutated in-place, but dict items can't be reassigned.
@@ -44,10 +47,10 @@ def __len__(self):
4447
def __getitem__(self, key):
4548
return self._d[key]
4649

47-
if sys.version_info[0] >= 3:
50+
if sys.version > '3':
4851
# Compile the metaclass at runtime because it's invalid python2 syntax.
4952
exec("""
50-
class _PyRegionMeta(object, metaclass=ABCMeta):
53+
class _PyRegionMeta(object):
5154
pass
5255
""")
5356
else:

bindings/py/tests/regions/pyregion_test.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,7 @@ def __init__(self):
5050
class PyRegionTest(unittest.TestCase):
5151

5252

53-
def testNoInit(self):
54-
"""Test unimplemented init method"""
55-
class NoInit(PyRegion):
56-
pass
5753

58-
with self.assertRaises(TypeError) as cw:
59-
_ni = NoInit()
60-
61-
self.assertEqual(str(cw.exception), "Can't instantiate abstract class " +
62-
"NoInit with abstract methods __init__, compute, initialize")
63-
64-
65-
def testUnimplementedAbstractMethods(self):
66-
"""Test unimplemented abstract methods"""
67-
# Test unimplemented getSpec (results in NotImplementedError)
68-
with self.assertRaises(NotImplementedError):
69-
X.getSpec()
70-
71-
# Test unimplemented abstract methods (x can't be instantiated)
72-
with self.assertRaises(TypeError) as cw:
73-
_x = X()
74-
75-
self.assertEqual(str(cw.exception), "Can't instantiate abstract class " +
76-
"X with abstract methods compute, initialize")
7754

7855
def testUnimplementedNotImplementedMethods(self):
7956
"""Test unimplemented @not_implemented methods"""
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
'''
2+
Created on 6 Nov 2019
3+
4+
@author: fred
5+
'''
6+
import unittest
7+
8+
import pickle
9+
import sys
10+
import numpy as np
11+
12+
from htm.advanced.support.register_regions import registerAllAdvancedRegions
13+
from htm.bindings.engine_internal import Network
14+
from htm.advanced.frameworks.location.location_network_creation import createL4L6aLocationColumn
15+
16+
cortical_params = {
17+
# Column parameters
18+
19+
# L2 Parameters
20+
# Adapted from htmresearch.frameworks.layers.l2_l4_inference.L4L2Experiment#getDefaultL2Params
21+
# L4 Parameters
22+
"l4_cellsPerColumn": 24,
23+
"l4_columnCount": 16,
24+
"l4_connectedPermanence": 0.6,
25+
"l4_permanenceIncrement": 0.1,
26+
"l4_permanenceDecrement": 0.02,
27+
"l4_apicalPredictedSegmentDecrement": 0.0,
28+
"l4_basalPredictedSegmentDecrement": 0.0,
29+
"l4_initialPermanence": 1.0,
30+
"l4_activationThreshold": 8,
31+
"l4_minThreshold": 8,
32+
"l4_reducedBasalThreshold": 8,
33+
"l4_sampleSize": 10,
34+
"l4_implementation": "ApicalTiebreak",
35+
36+
# L6a Parameters
37+
"l6a_moduleCount": 10,
38+
"l6a_dimensions": 2,
39+
"l6a_connectedPermanence": 0.5,
40+
"l6a_permanenceIncrement": 0.1,
41+
"l6a_permanenceDecrement": 0.0,
42+
"l6a_initialPermanence": 1.0,
43+
"l6a_activationThreshold": 8,
44+
"l6a_initialPermanence": 1.0,
45+
"l6a_learningThreshold": 8,
46+
"l6a_sampleSize": 10,
47+
"l6a_cellsPerAxis": 10,
48+
"l6a_scale": 10,
49+
"l6a_orientation": 60,
50+
"l6a_bumpOverlapMethod": "probabilistic"
51+
}
52+
53+
class TestSimpleSPTMNetwork(unittest.TestCase):
54+
55+
def _create_network(self, L4Params, L6aParams):
56+
"""
57+
Constructor.
58+
"""
59+
network = Network()
60+
61+
# Create network
62+
network = createL4L6aLocationColumn(network=network,
63+
L4Params=L4Params,
64+
L6aParams=L6aParams,
65+
inverseReadoutResolution=None,
66+
baselineCellsPerAxis=6,
67+
suffix="")
68+
69+
network.initialize()
70+
return network
71+
72+
def setUp(self):
73+
registerAllAdvancedRegions()
74+
75+
self._params = cortical_params
76+
77+
L4Params = {param.split("_")[1]:value for param, value in self._params.items() if param.startswith("l4")}
78+
L6aParams = {param.split("_")[1]:value for param, value in self._params.items() if param.startswith("l6a")}
79+
# Configure L6a self._htm_parameters
80+
numModules = L6aParams["moduleCount"]
81+
L6aParams["scale"] = [L6aParams["scale"]] * numModules
82+
angle = L6aParams["orientation"] // numModules
83+
orientation = list(range(angle // 2, angle * numModules, angle))
84+
L6aParams["orientation"] = np.radians(orientation).tolist()
85+
86+
self.network = self._create_network(L4Params, L6aParams)
87+
88+
def tearDown(self):
89+
self.network = None
90+
91+
def _run_network(self, network):
92+
"""
93+
Run the network with fixed data.
94+
"""
95+
motorInput = network.getRegion("motorInput")
96+
sensorInput = network.getRegion("sensorInput")
97+
motorInput.executeCommand('addDataToQueue', [0,0])
98+
sensorInput.executeCommand('addDataToQueue', [1,2,3], False, 0)
99+
100+
network.run(1)
101+
102+
L4Region = network.getRegion("L4")
103+
activeL4Cells = np.array(L4Region.getOutputArray("activeCells")).nonzero()[0]
104+
105+
return activeL4Cells
106+
107+
def testAL246aCorticalColumnPickle(self):
108+
"""
109+
Test that L246aCorticalColumn can be pickled.
110+
"""
111+
if sys.version_info[0] >= 3:
112+
proto = 3
113+
else:
114+
proto = 2
115+
# Simple test: make sure that dumping / loading works...
116+
pickledColumn = pickle.dumps(self.network, proto)
117+
network2 = pickle.loads(pickledColumn)
118+
s1 = self._run_network(self.network)
119+
s2 = self._run_network(network2)
120+
self.assertTrue(np.array_equal(s1, s2))
121+
122+
123+
if __name__ == "__main__":
124+
#import sys;sys.argv = ['', 'Test.testName']
125+
unittest.main()

setup.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from setuptools.command.test import test as BaseTestCommand
3030
# see https://stackoverflow.com/questions/44323474/distutils-core-vs-setuptools-with-c-extension
3131
from setuptools import Extension
32+
from pathlib import Path
33+
from sys import version_info
3234

3335
# NOTE: To debug the python bindings in a debugger, use the procedure
3436
# described here: https://pythonextensionpatterns.readthedocs.io/en/latest/debugging/debug_in_ide.html
@@ -264,6 +266,11 @@ def generateExtensions(platform, build_type):
264266
and then create the extension libraries in Repository/build/Release/distr/src/nupic/bindings.
265267
Note: for Windows it will force a X64 build.
266268
"""
269+
270+
# Make sure that .py code gets copied during any build so it is included in .whl
271+
if version_info >= (3, 4):
272+
Path(os.path.join(PY_BINDINGS, 'cpp_src', 'CMakeLists.txt')).touch()
273+
267274
cwd = os.getcwd()
268275
scriptsDir = os.path.join(REPO_DIR, "build", "scripts")
269276
try:
@@ -290,7 +297,6 @@ def configure(platform, build_type):
290297
cwd = os.getcwd()
291298

292299
print("Python version: {}\n".format(sys.version))
293-
from sys import version_info
294300
if version_info > (3, 0):
295301
# Build a Python 3.x library
296302
PY_VER = "-DBINDING_BUILD=Python3"

src/htm/engine/Input.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ void Input::addLink(const std::shared_ptr<Link> link, std::shared_ptr<Output> sr
5555
// Make sure we don't already have a link to the same output
5656
for (const auto &it : links_) {
5757
const Output* o = (*it).getSrc();
58-
NTA_CHECK(srcOutput.get() != o) << "addLink -- link from region "
59-
<< srcOutput->getRegion()->getName() << " output "
60-
<< srcOutput->getName() << " to region " << region_->getName()
61-
<< " input " << getName() << " already exists";
58+
NTA_CHECK(srcOutput.get() != o) << "Input::addLink() -- link from output="
59+
<< srcOutput->getRegion()->getName() << "."
60+
<< srcOutput->getName() << " to input=" << region_->getName()
61+
<< "." << getName() << " already exists";
6262
}
6363

6464
links_.push_back(link);

src/htm/engine/Network.cpp

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,11 @@ std::shared_ptr<Region> Network::addRegion(std::shared_ptr<Region>& r) {
109109
NTA_CHECK(r != nullptr);
110110
r->network_ = this;
111111
regions_[r->getName()] = r;
112+
113+
// If a region is added, initially set the phase to the default phase.
114+
// The phase can be changed later.
115+
setDefaultPhase_(r.get());
112116

113-
// We must make a copy of the phases set here because
114-
// setPhases_ will be passing this back down into
115-
// the region.
116-
std::set<UInt32> phases = r->getPhases();
117-
setPhases_(r.get(), phases);
118117
return r;
119118
}
120119

@@ -155,13 +154,11 @@ void Network::setPhases_(Region *r, std::set<UInt32> &phases) {
155154
if (item != phaseInfo_[i].end() && !insertPhase) {
156155
phaseInfo_[i].erase(item);
157156
} else if (insertPhase) {
157+
// add the new phase(s_
158158
phaseInfo_[i].insert(r);
159159
}
160160
}
161161

162-
// keep track (redundantly) of phases inside the Region also, for
163-
// serialization
164-
r->setPhases(phases);
165162

166163
resetEnabledPhases_();
167164
}
@@ -443,12 +440,10 @@ std::shared_ptr<Region> Network::getRegion(const std::string& name) const {
443440
std::vector<std::shared_ptr<Link>> Network::getLinks() const {
444441
std::vector<std::shared_ptr<Link>> links;
445442

446-
for (UInt32 phase = minEnabledPhase_; phase <= maxEnabledPhase_; phase++) {
447-
for (auto r : phaseInfo_[phase]) {
448-
for (auto &input : r->getInputs()) {
449-
for (auto &link : input.second->getLinks()) {
450-
links.push_back(link);
451-
}
443+
for (auto p : regions_) {
444+
for (auto &input : p.second->getInputs()) {
445+
for (auto &link : input.second->getLinks()) {
446+
links.push_back(link);
452447
}
453448
}
454449
}
@@ -521,13 +516,9 @@ void Network::post_load() {
521516
for(auto p: regions_) {
522517
std::shared_ptr<Region>& r = p.second;
523518
r->network_ = this;
524-
std::set<UInt32> phases = r->getPhases();
525-
setPhases_(r.get(), phases);
526519
r->evaluateLinks(); // Create the input buffers.
527520
}
528521

529-
NTA_CHECK(maxEnabledPhase_ < phaseInfo_.size())
530-
<< "maxphase: " << maxEnabledPhase_ << " size: " << phaseInfo_.size();
531522

532523
// Note: When serialized, the output buffers are saved
533524
// by each Region. After restore we need to
@@ -564,6 +555,65 @@ void Network::post_load() {
564555

565556
}
566557

558+
std::string Network::phasesToString() const {
559+
std::stringstream ss;
560+
ss << "{";
561+
ss << "minEnabledPhase_: " << minEnabledPhase_ << ", ";
562+
ss << "maxEnabledPhase_: " << maxEnabledPhase_ << ", ";
563+
ss << "info: [";
564+
for (auto phase : phaseInfo_) {
565+
ss << "[";
566+
for (auto region : phase) {
567+
ss << region->getName() << ", ";
568+
}
569+
ss << "]";
570+
}
571+
ss << "]}";
572+
return ss.str();
573+
}
574+
void Network::phasesFromString(const std::string& phaseString) {
575+
std::string content = phaseString;
576+
content.erase(std::remove(content.begin(), content.end(), ','), content.end());
577+
std::stringstream ss(content);
578+
std::string tag;
579+
std::set<Region *> phase;
580+
581+
NTA_CHECK(ss.peek() == '{') << "Invalid phase deserialization";
582+
ss.ignore(1);
583+
ss >> tag;
584+
NTA_CHECK(tag == "minEnabledPhase_:");
585+
ss >> minEnabledPhase_;
586+
ss >> tag;
587+
NTA_CHECK(tag == "maxEnabledPhase_:");
588+
ss >> maxEnabledPhase_;
589+
ss >> tag;
590+
NTA_CHECK(tag == "info:") << "Invalid phase deserialization";
591+
ss >> std::ws;
592+
NTA_CHECK(ss.peek() == '[') << "Invalid phase deserialization";
593+
ss.ignore(1);
594+
ss >> std::ws;
595+
while (ss.peek() != ']') {
596+
ss >> std::ws;
597+
if (ss.peek() == '[') {
598+
ss.ignore(1);
599+
ss >> std::ws;
600+
while (ss.peek() != ']') {
601+
ss >> tag;
602+
auto it = regions_.find(tag);
603+
NTA_CHECK(it != regions_.end()) << "Region '" << tag << "' not found while decoding phase.";
604+
phase.insert(it->second.get());
605+
ss >> std::ws;
606+
}
607+
ss.ignore(1); // ']'
608+
phaseInfo_.push_back(phase);
609+
phase.clear();
610+
}
611+
}
612+
ss >> std::ws;
613+
ss.ignore(1); // ']'
614+
}
615+
616+
567617
void Network::enableProfiling() {
568618
for (auto p: regions_) {
569619
std::shared_ptr<Region> r = p.second;

0 commit comments

Comments
 (0)