Skip to content

Commit fac5da7

Browse files
committed
Merge branch 'master_community' into cleanup_exceptions
2 parents 2149e83 + 1aca435 commit fac5da7

File tree

22 files changed

+567
-324
lines changed

22 files changed

+567
-324
lines changed

bindings/py/cpp_src/bindings/engine/py_Engine.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,21 @@ namespace htm_ext
268268

269269
r.setInputData(name, s);
270270
});
271+
272+
py_Region.def(py::pickle(
273+
[](const Region& self) {
274+
std::stringstream ss;
275+
self.save(ss);
276+
return py::bytes(ss.str());
277+
},
278+
// Note: a de-serialized Region will need to be reattached to a Network
279+
// before it could be used. See Network::addRegion( Region*)
280+
[](const py::bytes& s) {
281+
std::istringstream ss(s);
282+
Region self;
283+
self.load(ss);
284+
return self;
285+
}));
271286

272287

273288
py_Region.def("getParameterInt32", &Region::getParameterInt32)
@@ -408,6 +423,11 @@ namespace htm_ext
408423
, py::arg("name")
409424
, py::arg("nodeType" )
410425
, py::arg("nodeParams"));
426+
py_Network.def("addRegion", (Region_Ptr_t (htm::Network::*)(
427+
Region_Ptr_t&))
428+
&htm::Network::addRegion,
429+
"add region for deserialization."
430+
, py::arg("region"));
411431

412432
py_Network.def("getRegions", &htm::Network::getRegions)
413433
.def("getRegion", &htm::Network::getRegion)
@@ -426,8 +446,21 @@ namespace htm_ext
426446
py_Network.def("save", &htm::Network::save)
427447
.def("load", &htm::Network::load)
428448
.def("saveToFile", &htm::Network::saveToFile, py::arg("file"), py::arg("fmt") = SerializableFormat::BINARY)
429-
.def("loadFromFile", &htm::Network::loadFromFile, py::arg("file"), py::arg("fmt") = SerializableFormat::BINARY);
449+
.def("loadFromFile", &htm::Network::loadFromFile, py::arg("file"), py::arg("fmt") = SerializableFormat::BINARY)
450+
.def("__eq__", &htm::Network::operator==);
430451

452+
py_Network.def(py::pickle(
453+
[](const Network& self) {
454+
std::stringstream ss;
455+
self.save(ss);
456+
return py::bytes(ss.str());
457+
},
458+
[](const py::bytes& s) {
459+
std::istringstream ss(s);
460+
Network self;
461+
self.load(ss);
462+
return self;
463+
}));
431464

432465
py_Network.def("link", &htm::Network::link
433466
, "Defines a link between regions"

bindings/py/cpp_src/plugin/PyBindRegion.cpp

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,8 @@ namespace py = pybind11;
291291
, className_(className)
292292

293293
{
294+
// Make a local copy of the Spec
295+
createSpec(module_.c_str(), nodeSpec_, className_.c_str());
294296

295297
cereal_adapter_load(wrapper);
296298
}
@@ -304,30 +306,39 @@ namespace py = pybind11;
304306
// 1. serialize main state using pickle
305307
// 2. call class method to serialize external state
306308

307-
// 1. Serialize main state of the Python module
308-
// We want this to end up in the open stream obtained from bundle.
309-
// a. We first pickle the python into a temporary file.
310-
// b. copy the file into our open stream.
311-
312-
std::string tmp_pickle = "pickle.tmp";
313-
py::tuple args = py::make_tuple(tmp_pickle, "wb");
314-
auto f = py::module::import("__builtin__").attr("file")(*args);
309+
// Serialize main state of the Python module
310+
// We want this to end up in a string that we pass back to Cereal.
311+
// a. We first pickle the python into an in-memory byte stream.
312+
// b. We then convert that to a Base64 std::string that is returned.
313+
//
314+
// Basicly, we are executing the following Python code:
315+
// import io
316+
// import base64
317+
// import pickle
318+
// f = io.BytesIO()
319+
// pickle.dump(node, f, 3)
320+
// b = f.getvalue()
321+
// content = str(base64.b64encode(b))
322+
// f.close()
323+
324+
py::tuple args;
325+
auto f = py::module::import("io").attr("BytesIO")();
315326

316327
#if PY_MAJOR_VERSION >= 3
317328
auto pickle = py::module::import("pickle");
329+
args = py::make_tuple(node_, f, 3); // use type 3 protocol
318330
#else
319331
auto pickle = py::module::import("cPickle");
320-
#endif
321332
args = py::make_tuple(node_, f, 2); // use type 2 protocol
333+
#endif
322334
pickle.attr("dump")(*args);
323-
pickle.attr("close")();
324-
325-
// copy the pickle into the out string
326-
std::ifstream pfile(tmp_pickle.c_str(), std::ios::binary);
327-
std::string content((std::istreambuf_iterator<char>(pfile)),
328-
std::istreambuf_iterator<char>());
329-
pfile.close();
330-
Path::remove(tmp_pickle);
335+
336+
// copy the pickle stream into the content as a base64 encoded utf8 string
337+
py::bytes b = f.attr("getvalue")();
338+
args = py::make_tuple(b);
339+
std::string content = py::str(py::module::import("base64").attr("b64encode")(*args));
340+
341+
f.attr("close")();
331342
return content;
332343
}
333344
std::string PyBindRegion::extraSerialize() const
@@ -354,31 +365,34 @@ namespace py = pybind11;
354365
void PyBindRegion::pickleDeserialize(std::string p) {
355366
// 1. deserialize main state using pickle
356367
// 2. call class method to deserialize external state
357-
358-
std::ofstream des;
359-
std::string tmp_pickle = "pickle.tmp";
360-
361-
362-
std::ofstream pfile(tmp_pickle.c_str(), std::ios::binary);
363-
pfile.write(p.c_str(), p.size());
364-
pfile.close();
365-
366-
367-
// Tell Python to un-pickle using what is now in the pickle.tmp file.
368-
py::args args = py::make_tuple(tmp_pickle, "rb");
369-
auto f = py::module::import("__builtin__").attr("file")(*args);
368+
//
369+
// Tell Python to un-pickle using what is in the string p.
370+
// but first we need to convert the base64 string into bytes.
371+
//
372+
// Basically we are executing the following Python code:
373+
// import base64
374+
// import io
375+
// import pickle
376+
// b = base64.b64decode(bytes(p))
377+
// f = io.BytesIO(b)
378+
// node = pickle.load(f)
379+
// f.close()
380+
381+
py::args args;
382+
args = py::make_tuple(py::bytes(p));
383+
py::bytes b = py::module::import("base64").attr("b64decode")(*args);
384+
args = py::make_tuple(b);
385+
auto f = py::module::import("io").attr("BytesIO")(*args);
370386

371387
#if PY_MAJOR_VERSION >= 3
372388
auto pickle = py::module::import("pickle");
373389
#else
374390
auto pickle = py::module::import("cPickle");
375391
#endif
392+
args = py::make_tuple(f);
393+
node_ = pickle.attr("load")(*args);
376394

377-
args = py::make_tuple(node_, f);
378-
pickle.attr("load")(*args);
379-
380-
pickle.attr("close")();
381-
Path::remove(tmp_pickle);
395+
f.attr("close")();
382396
}
383397

384398
void PyBindRegion::extraDeserialize(std::string e) {

bindings/py/cpp_src/plugin/PyBindRegion.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ namespace htm
131131

132132

133133
private:
134-
std::string module_;
135-
std::string className_;
134+
std::string module_; // Full path to the class.
135+
std::string className_; // Just the name of the class.
136136

137137
pybind11::object node_;
138138

bindings/py/cpp_src/plugin/RegisteredRegionImplPy.hpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ namespace htm
103103
{
104104
throw Exception(__FILE__, __LINE__, e.what());
105105
}
106+
catch (const py::cast_error& e)
107+
{
108+
throw Exception(__FILE__, __LINE__, e.what());
109+
}
106110
catch (htm::Exception & e)
107111
{
108112
throw htm::Exception(e);
@@ -127,6 +131,10 @@ namespace htm
127131
{
128132
throw Exception(__FILE__, __LINE__, e.what());
129133
}
134+
catch (const py::cast_error& e)
135+
{
136+
throw Exception(__FILE__, __LINE__, e.what());
137+
}
130138
catch (htm::Exception & e)
131139
{
132140
throw htm::Exception(e);
@@ -160,7 +168,7 @@ namespace htm
160168
* when its name is used in a Network::addRegion() call.
161169
*
162170
* @param className -- the name of the Python class that implements the region.
163-
* @param module -- the module (shared library) in which the class resides.
171+
* @param module -- the module (full path and file) in which the class resides.
164172
*/
165173
inline static void registerPyRegion(const std::string& module, const std::string& className) {
166174
std::string nodeType = "py." + className;

bindings/py/packaging/src/htm/bindings/regions/PyRegion.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424
import collections
2525

2626
RealNumpyDType = numpy.float32
27-
from abc import ABCMeta, abstractmethod
28-
29-
class DictReadOnlyWrapper(collections.Mapping):
27+
if sys.version > '3':
28+
from abc import ABC,abstractmethod
29+
# see http://www.programmersought.com/article/7351237937/
30+
else:
31+
from abc import ABCMeta, abstractmethod
32+
class DictReadOnlyWrapper(collections.abc.Mapping):
3033
"""
3134
Provides read-only access to a dict. When dict items are mutable, they can
3235
still be mutated in-place, but dict items can't be reassigned.
@@ -44,10 +47,10 @@ def __len__(self):
4447
def __getitem__(self, key):
4548
return self._d[key]
4649

47-
if sys.version_info[0] >= 3:
50+
if sys.version > '3':
4851
# Compile the metaclass at runtime because it's invalid python2 syntax.
4952
exec("""
50-
class _PyRegionMeta(object, metaclass=ABCMeta):
53+
class _PyRegionMeta(object):
5154
pass
5255
""")
5356
else:
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# ----------------------------------------------------------------------
2+
# HTM Community Edition of NuPIC
3+
# Copyright (C) 2015, Numenta, Inc.
4+
#
5+
# This program is free software: you can redistribute it and/or modify
6+
# it under the terms of the GNU Affero Public License version 3 as
7+
# published by the Free Software Foundation.
8+
#
9+
# This program is distributed in the hope that it will be useful,
10+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
12+
# See the GNU Affero Public License for more details.
13+
#
14+
# You should have received a copy of the GNU Affero Public License
15+
# along with this program. If not, see http://www.gnu.org/licenses.
16+
# ----------------------------------------------------------------------

bindings/py/tests/regions/network_test.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
import unittest
2020
import pytest
2121
import numpy as np
22+
import sys
23+
import os
24+
import pickle
2225

26+
2327
from htm.bindings.regions.PyRegion import PyRegion
2428
from htm.bindings.sdr import SDR
2529
import htm.bindings.engine_internal as engine
@@ -96,11 +100,12 @@ def setUp(self):
96100
engine.Network.cleanup()
97101
engine.Network.registerPyRegion(LinkRegion.__module__, LinkRegion.__name__)
98102

99-
@pytest.mark.skip(reason="pickle support needs work...another PR")
100103
def testSerializationWithPyRegion(self):
101104
"""Test (de)serialization of network containing a python region"""
102105
engine.Network.registerPyRegion(__name__,
103106
SerializationTestPyRegion.__name__)
107+
108+
file_path = "SerializationTest.stream"
104109
try:
105110
srcNet = engine.Network()
106111
srcNet.addRegion(SerializationTestPyRegion.__name__,
@@ -111,12 +116,20 @@ def testSerializationWithPyRegion(self):
111116
}))
112117

113118
# Serialize
114-
srcNet.saveToFile("SerializationTest.stream")
119+
# Note: This will do the following:
120+
# - Call network.saveToFile(), in C++. this opens the file.
121+
# - that calls network.save(stream)
122+
# - that will use Cereal to serialize the Network object.
123+
# - that will serialize the Region object.
124+
# - that will serialize PyBindRegion object because this is a python Region.
125+
# - that will use pickle to serialize SerializationTestPyRegion in
126+
# serialization_test_py_region.py into Base64.
127+
srcNet.saveToFile(file_path, engine.SerializableFormat.BINARY)
115128

116129

117130
# Deserialize
118131
destNet = engine.Network()
119-
destNet.loadFromFile("SerializationTest.stream")
132+
destNet.loadFromFile(file_path)
120133

121134
destRegion = destNet.getRegion(SerializationTestPyRegion.__name__)
122135

@@ -125,6 +138,8 @@ def testSerializationWithPyRegion(self):
125138

126139
finally:
127140
engine.Network.unregisterPyRegion(SerializationTestPyRegion.__name__)
141+
if os.path.isfile(file_path):
142+
os.unlink("SerializationTest.stream")
128143

129144

130145
def testSimpleTwoRegionNetworkIntrospection(self):
@@ -174,7 +189,6 @@ def testNetworkLinkTypeValidation(self):
174189
network.link("from", "to", "", "", "UInt32", "Real32")
175190

176191

177-
@pytest.mark.skip(reason="parameter types don't match.")
178192
def testParameters(self):
179193

180194
n = engine.Network()
@@ -333,4 +347,35 @@ def testExecuteCommand2(self):
333347
result = r.executeCommand("HelloWorld", 42, lst)
334348
self.assertTrue(result == "Hello World says: arg1=42 arg2=['list arg', 86]")
335349

350+
def testNetworkPickle(self):
351+
"""
352+
Test region pickling/unpickling.
353+
"""
354+
network = engine.Network()
355+
r_from = network.addRegion("from", "py.LinkRegion", "")
356+
r_to = network.addRegion("to", "py.LinkRegion", "")
357+
cnt = r_from.getOutputElementCount("UInt32")
358+
self.assertEqual(5, cnt)
359+
360+
network.link("from", "to", "", "", "UInt32", "UInt32")
361+
network.link("from", "to", "", "", "Real32", "Real32")
362+
network.link("from", "to", "", "", "Real32", "UInt32")
363+
network.link("from", "to", "", "", "UInt32", "Real32")
364+
network.initialize()
365+
366+
if sys.version_info[0] >= 3:
367+
proto = 3
368+
else:
369+
proto = 2
370+
371+
# Simple test: make sure that dumping / loading works...
372+
pickledNetwork = pickle.dumps(network, proto)
373+
network2 = pickle.loads(pickledNetwork)
374+
375+
s1 = network.getRegion("to").executeCommand("HelloWorld", "26", "64");
376+
s2 = network2.getRegion("to").executeCommand("HelloWorld", "26", "64");
377+
378+
self.assertEqual(s1,"Hello World says: arg1=26 arg2=64")
379+
self.assertEqual(s1, s2, "Simple Network pickle/unpickle failed.")
380+
336381

0 commit comments

Comments
 (0)