Skip to content

Commit cb5e0f1

Browse files
authored
Merge pull request #741 from htm-community/pickled_network
Fix for pickling a Network
2 parents 038f4d4 + 5345a59 commit cb5e0f1

File tree

9 files changed

+315
-84
lines changed

9 files changed

+315
-84
lines changed

bindings/py/cpp_src/bindings/engine/py_Engine.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,21 @@ namespace htm_ext
267267

268268
r.setInputData(name, s);
269269
});
270+
271+
py_Region.def(py::pickle(
272+
[](const Region& self) {
273+
std::stringstream ss;
274+
self.save(ss);
275+
return py::bytes(ss.str());
276+
},
277+
// Note: a de-serialized Region will need to be reattached to a Network
278+
// before it could be used. See Network::addRegion( Region*)
279+
[](const py::bytes& s) {
280+
std::istringstream ss(s);
281+
Region self;
282+
self.load(ss);
283+
return self;
284+
}));
270285

271286

272287
py_Region.def("getParameterInt32", &Region::getParameterInt32)
@@ -407,6 +422,11 @@ namespace htm_ext
407422
, py::arg("name")
408423
, py::arg("nodeType" )
409424
, py::arg("nodeParams"));
425+
// Expose the Network::addRegion overload that accepts an already
// constructed region (Region_Ptr_t&). The cast only selects that overload
// among the addRegion functions.
py_Network.def(
    "addRegion",
    static_cast<Region_Ptr_t (htm::Network::*)(Region_Ptr_t&)>(
        &htm::Network::addRegion),
    "add region for deserialization.",
    py::arg("region"));
410430

411431
py_Network.def("getRegions", &htm::Network::getRegions)
412432
.def("getRegion", &htm::Network::getRegion)
@@ -427,6 +447,18 @@ namespace htm_ext
427447
.def("saveToFile", &htm::Network::saveToFile, py::arg("file"), py::arg("fmt") = SerializableFormat::BINARY)
428448
.def("loadFromFile", &htm::Network::loadFromFile, py::arg("file"), py::arg("fmt") = SerializableFormat::BINARY);
429449

450+
py_Network.def(py::pickle(
451+
[](const Network& self) {
452+
std::stringstream ss;
453+
self.save(ss);
454+
return py::bytes(ss.str());
455+
},
456+
[](const py::bytes& s) {
457+
std::istringstream ss(s);
458+
Network self;
459+
self.load(ss);
460+
return self;
461+
}));
430462

431463
py_Network.def("link", &htm::Network::link
432464
, "Defines a link between regions"

bindings/py/cpp_src/plugin/PyBindRegion.cpp

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -304,30 +304,39 @@ namespace py = pybind11;
304304
// 1. serialize main state using pickle
305305
// 2. call class method to serialize external state
306306

307-
// 1. Serialize main state of the Python module
308-
// We want this to end up in the open stream obtained from bundle.
309-
// a. We first pickle the python into a temporary file.
310-
// b. copy the file into our open stream.
311-
312-
std::string tmp_pickle = "pickle.tmp";
313-
py::tuple args = py::make_tuple(tmp_pickle, "wb");
314-
auto f = py::module::import("__builtin__").attr("file")(*args);
307+
// Serialize main state of the Python module
308+
// We want this to end up in a string that we pass back to Cereal.
309+
// a. We first pickle the python into an in-memory byte stream.
310+
// b. We then convert that to a Base64 std::string that is returned.
311+
//
312+
// Basically, we are executing the following Python code:
313+
// import io
314+
// import base64
315+
// import pickle
316+
// f = io.BytesIO()
317+
// pickle.dump(node, f, 3)
318+
// b = f.getvalue()
319+
// content = str(base64.b64encode(b))
320+
// f.close()
321+
322+
py::tuple args;
323+
auto f = py::module::import("io").attr("BytesIO")();
315324

316325
#if PY_MAJOR_VERSION >= 3
317326
auto pickle = py::module::import("pickle");
327+
args = py::make_tuple(node_, f, 3); // use type 3 protocol
318328
#else
319329
auto pickle = py::module::import("cPickle");
320-
#endif
321330
args = py::make_tuple(node_, f, 2); // use type 2 protocol
331+
#endif
322332
pickle.attr("dump")(*args);
323-
pickle.attr("close")();
324-
325-
// copy the pickle into the out string
326-
std::ifstream pfile(tmp_pickle.c_str(), std::ios::binary);
327-
std::string content((std::istreambuf_iterator<char>(pfile)),
328-
std::istreambuf_iterator<char>());
329-
pfile.close();
330-
Path::remove(tmp_pickle);
333+
334+
// copy the pickle stream into the content as a base64 encoded utf8 string
335+
py::bytes b = f.attr("getvalue")();
336+
args = py::make_tuple(b);
337+
std::string content = py::str(py::module::import("base64").attr("b64encode")(*args));
338+
339+
f.attr("close")();
331340
return content;
332341
}
333342
std::string PyBindRegion::extraSerialize() const
@@ -354,31 +363,34 @@ namespace py = pybind11;
354363
void PyBindRegion::pickleDeserialize(std::string p) {
355364
// 1. deserialize main state using pickle
356365
// 2. call class method to deserialize external state
357-
358-
std::ofstream des;
359-
std::string tmp_pickle = "pickle.tmp";
360-
361-
362-
std::ofstream pfile(tmp_pickle.c_str(), std::ios::binary);
363-
pfile.write(p.c_str(), p.size());
364-
pfile.close();
365-
366-
367-
// Tell Python to un-pickle using what is now in the pickle.tmp file.
368-
py::args args = py::make_tuple(tmp_pickle, "rb");
369-
auto f = py::module::import("__builtin__").attr("file")(*args);
366+
//
367+
// Tell Python to un-pickle using what is in the string p.
368+
// but first we need to convert the base64 string into bytes.
369+
//
370+
// Basically we are executing the following Python code:
371+
// import base64
372+
// import io
373+
// import pickle
374+
// b = base64.b64decode(bytes(p))
375+
// f = io.BytesIO(b)
376+
// node = pickle.load(f)
377+
// f.close()
378+
379+
py::args args;
380+
args = py::make_tuple(py::bytes(p));
381+
py::bytes b = py::module::import("base64").attr("b64decode")(*args);
382+
args = py::make_tuple(b);
383+
auto f = py::module::import("io").attr("BytesIO")(*args);
370384

371385
#if PY_MAJOR_VERSION >= 3
372386
auto pickle = py::module::import("pickle");
373387
#else
374388
auto pickle = py::module::import("cPickle");
375389
#endif
390+
args = py::make_tuple(f);
391+
node_ = pickle.attr("load")(*args);
376392

377-
args = py::make_tuple(node_, f);
378-
pickle.attr("load")(*args);
379-
380-
pickle.attr("close")();
381-
Path::remove(tmp_pickle);
393+
f.attr("close")();
382394
}
383395

384396
void PyBindRegion::extraDeserialize(std::string e) {

bindings/py/cpp_src/plugin/RegisteredRegionImplPy.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ namespace htm
103103
{
104104
throw Exception(__FILE__, __LINE__, e.what());
105105
}
106+
catch (const py::cast_error& e)
107+
{
108+
throw Exception(__FILE__, __LINE__, e.what());
109+
}
106110
catch (htm::Exception & e)
107111
{
108112
throw htm::Exception(e);
@@ -127,6 +131,10 @@ namespace htm
127131
{
128132
throw Exception(__FILE__, __LINE__, e.what());
129133
}
134+
catch (const py::cast_error& e)
135+
{
136+
throw Exception(__FILE__, __LINE__, e.what());
137+
}
130138
catch (htm::Exception & e)
131139
{
132140
throw htm::Exception(e);
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# ----------------------------------------------------------------------
2+
# HTM Community Edition of NuPIC
3+
# Copyright (C) 2015, Numenta, Inc.
4+
#
5+
# This program is free software: you can redistribute it and/or modify
6+
# it under the terms of the GNU Affero Public License version 3 as
7+
# published by the Free Software Foundation.
8+
#
9+
# This program is distributed in the hope that it will be useful,
10+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
12+
# See the GNU Affero Public License for more details.
13+
#
14+
# You should have received a copy of the GNU Affero Public License
15+
# along with this program. If not, see http://www.gnu.org/licenses.
16+
# ----------------------------------------------------------------------

bindings/py/tests/regions/network_test.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
import unittest
2020
import pytest
2121
import numpy as np
22+
import sys
23+
import os
24+
import pickle
2225

26+
2327
from htm.bindings.regions.PyRegion import PyRegion
2428
from htm.bindings.sdr import SDR
2529
import htm.bindings.engine_internal as engine
@@ -96,11 +100,12 @@ def setUp(self):
96100
engine.Network.cleanup()
97101
engine.Network.registerPyRegion(LinkRegion.__module__, LinkRegion.__name__)
98102

99-
@pytest.mark.skip(reason="pickle support needs work...another PR")
100103
def testSerializationWithPyRegion(self):
101104
"""Test (de)serialization of network containing a python region"""
102105
engine.Network.registerPyRegion(__name__,
103106
SerializationTestPyRegion.__name__)
107+
108+
file_path = "SerializationTest.stream"
104109
try:
105110
srcNet = engine.Network()
106111
srcNet.addRegion(SerializationTestPyRegion.__name__,
@@ -111,12 +116,20 @@ def testSerializationWithPyRegion(self):
111116
}))
112117

113118
# Serialize
114-
srcNet.saveToFile("SerializationTest.stream")
119+
# Note: This will do the following:
120+
# - Call network.saveToFile(), in C++. this opens the file.
121+
# - that calls network.save(stream)
122+
# - that will use Cereal to serialize the Network object.
123+
# - that will serialize the Region object.
124+
# - that will serialize PyBindRegion object because this is a python Region.
125+
# - that will use pickle to serialize SerializationTestPyRegion in
126+
# serialization_test_py_region.py into Base64.
127+
srcNet.saveToFile(file_path, engine.SerializableFormat.BINARY)
115128

116129

117130
# Deserialize
118131
destNet = engine.Network()
119-
destNet.loadFromFile("SerializationTest.stream")
132+
destNet.loadFromFile(file_path)
120133

121134
destRegion = destNet.getRegion(SerializationTestPyRegion.__name__)
122135

@@ -125,6 +138,8 @@ def testSerializationWithPyRegion(self):
125138

126139
finally:
127140
engine.Network.unregisterPyRegion(SerializationTestPyRegion.__name__)
141+
if os.path.isfile(file_path):
142+
os.unlink("SerializationTest.stream")
128143

129144

130145
def testSimpleTwoRegionNetworkIntrospection(self):
@@ -174,7 +189,6 @@ def testNetworkLinkTypeValidation(self):
174189
network.link("from", "to", "", "", "UInt32", "Real32")
175190

176191

177-
@pytest.mark.skip(reason="parameter types don't match.")
178192
def testParameters(self):
179193

180194
n = engine.Network()
@@ -331,4 +345,35 @@ def testExecuteCommand2(self):
331345
result = r.executeCommand("HelloWorld", 42, lst)
332346
self.assertTrue(result == "Hello World says: arg1=42 arg2=['list arg', 86]")
333347

348+
def testNetworkPickle(self):
  """
  Test Network pickling/unpickling.

  Builds a two-region network with four links, initializes it, round-trips
  it through pickle and checks that the "to" region of the restored network
  answers the "HelloWorld" command exactly like the original one.
  """
  network = engine.Network()
  r_from = network.addRegion("from", "py.LinkRegion", "")
  network.addRegion("to", "py.LinkRegion", "")
  self.assertEqual(5, r_from.getOutputElementCount("UInt32"))

  network.link("from", "to", "", "", "UInt32", "UInt32")
  network.link("from", "to", "", "", "Real32", "Real32")
  network.link("from", "to", "", "", "Real32", "UInt32")
  network.link("from", "to", "", "", "UInt32", "Real32")
  network.initialize()

  # Pickle protocol 3 on Python 3, protocol 2 on Python 2.
  proto = 3 if sys.version_info[0] >= 3 else 2

  # Simple test: make sure that dumping / loading works...
  pickledNetwork = pickle.dumps(network, proto)
  network2 = pickle.loads(pickledNetwork)

  s1 = network.getRegion("to").executeCommand("HelloWorld", "26", "64")
  s2 = network2.getRegion("to").executeCommand("HelloWorld", "26", "64")

  self.assertEqual(s1, "Hello World says: arg1=26 arg2=64")
  self.assertEqual(s1, s2, "Simple Network pickle/unpickle failed.")
378+
334379

bindings/py/tests/regions/pyregion_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# ----------------------------------------------------------------------
1717

1818
import unittest
19+
import sys
1920

2021
from htm.bindings.regions.PyRegion import PyRegion
2122

@@ -104,6 +105,25 @@ def testCallUnimplementedMethod(self):
104105
self.assertEqual(str(cw.exception),
105106
"The method setParameter is not implemented.")
106107

108+
def testPickle(self):
  """
  Test region pickling/unpickling.

  Round-trips a Y instance through pickle and checks that its `zzz`
  attribute survives.
  """
  y = Y()

  if sys.version_info[0] >= 3:
    import pickle
    proto = 3
  else:
    # The Python 2 accelerated pickle module is spelled "cPickle";
    # "cpickle" does not exist and would raise ImportError.
    import cPickle as pickle
    proto = 2

  # Simple test: make sure that dumping / loading works...
  pickledRegion = pickle.dumps(y, proto)
  y2 = pickle.loads(pickledRegion)
  self.assertEqual(y.zzz, y2.zzz, "Simple Region pickle/unpickle failed.")
125+
126+
107127

108128

109129
if __name__ == "__main__":

src/htm/engine/Network.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ Network::Network() {
4444
commonInit();
4545
}
4646

47+
// move constructor
48+
Network::Network(Network && n) {
49+
regions_ = std::move(n.regions_);
50+
minEnabledPhase_ = n.minEnabledPhase_;
51+
maxEnabledPhase_ = n.maxEnabledPhase_;
52+
phaseInfo_ = std::move(n.phaseInfo_);
53+
callbacks_ = n.callbacks_;
54+
iteration_ = n.iteration_;
55+
}
56+
4757
Network::Network(const std::string& filename) {
4858
commonInit();
4959
loadFromFile(filename);

src/htm/engine/Network.hpp

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,30 +60,20 @@ class Link;
6060
*
6161
* Create an new Network
6262
*
63-
* @Note if the Network object gets copied it does not do a
64-
* deep copy. So both copies point to the same set of
65-
* regions and links. The last Network object to go
66-
* out-of-scope will delete the regions and links.
6763
*/
6864
Network();
6965
Network(const std::string& filename);
7066

71-
/*
72-
* @Note: the pickle functions in the python bindings
73-
* require that the Network object be copyable.
74-
* The default copy constructor is ok.
67+
/**
68+
* Cannot copy or assign a Network object. But can be moved.
7569
*/
70+
Network(Network&&); // move is allowed
71+
Network(const Network&) = delete;
72+
void operator=(const Network&) = delete;
7673

7774
/**
7875
* Destructor.
7976
*
80-
* Destruct the network and unregister it from NuPIC:
81-
*
82-
* - Uninitialize all regions
83-
* - Remove all links
84-
* - Delete the regions themselves
85-
*
86-
* @todo Should we document the tear down steps above?
8777
*/
8878
~Network();
8979

0 commit comments

Comments
 (0)