1919import unittest
2020import pytest
2121import numpy as np
22+ import sys
23+ import os
24+ import pickle
2225
26+
2327from htm .bindings .regions .PyRegion import PyRegion
2428from htm .bindings .sdr import SDR
2529import htm .bindings .engine_internal as engine
@@ -96,11 +100,12 @@ def setUp(self):
96100 engine .Network .cleanup ()
97101 engine .Network .registerPyRegion (LinkRegion .__module__ , LinkRegion .__name__ )
98102
99- @pytest .mark .skip (reason = "pickle support needs work...another PR" )
100103 def testSerializationWithPyRegion (self ):
101104 """Test (de)serialization of network containing a python region"""
102105 engine .Network .registerPyRegion (__name__ ,
103106 SerializationTestPyRegion .__name__ )
107+
108+ file_path = "SerializationTest.stream"
104109 try :
105110 srcNet = engine .Network ()
106111 srcNet .addRegion (SerializationTestPyRegion .__name__ ,
@@ -111,12 +116,20 @@ def testSerializationWithPyRegion(self):
111116 }))
112117
113118 # Serialize
114- srcNet .saveToFile ("SerializationTest.stream" )
119+ # Note: This will do the following:
120+ # - Call network.saveToFile(), in C++. this opens the file.
121+ # - that calls network.save(stream)
122+ # - that will use Cereal to serialize the Network object.
123+ # - that will serialize the Region object.
124+ # - that will serialize PyBindRegion object because this is a python Region.
125+ # - that will use pickle to serialize SerializationTestPyRegion in
126+ # serialization_test_py_region.py into Base64.
127+ srcNet .saveToFile (file_path , engine .SerializableFormat .BINARY )
115128
116129
117130 # Deserialize
118131 destNet = engine .Network ()
119- destNet .loadFromFile ("SerializationTest.stream" )
132+ destNet .loadFromFile (file_path )
120133
121134 destRegion = destNet .getRegion (SerializationTestPyRegion .__name__ )
122135
@@ -125,6 +138,8 @@ def testSerializationWithPyRegion(self):
125138
126139 finally :
127140 engine .Network .unregisterPyRegion (SerializationTestPyRegion .__name__ )
141+ if os .path .isfile (file_path ):
142+ os .unlink ("SerializationTest.stream" )
128143
129144
130145 def testSimpleTwoRegionNetworkIntrospection (self ):
@@ -174,7 +189,6 @@ def testNetworkLinkTypeValidation(self):
174189 network .link ("from" , "to" , "" , "" , "UInt32" , "Real32" )
175190
176191
177- @pytest .mark .skip (reason = "parameter types don't match." )
178192 def testParameters (self ):
179193
180194 n = engine .Network ()
@@ -331,4 +345,35 @@ def testExecuteCommand2(self):
331345 result = r .executeCommand ("HelloWorld" , 42 , lst )
332346 self .assertTrue (result == "Hello World says: arg1=42 arg2=['list arg', 86]" )
333347
348+ def testNetworkPickle (self ):
349+ """
350+ Test region pickling/unpickling.
351+ """
352+ network = engine .Network ()
353+ r_from = network .addRegion ("from" , "py.LinkRegion" , "" )
354+ r_to = network .addRegion ("to" , "py.LinkRegion" , "" )
355+ cnt = r_from .getOutputElementCount ("UInt32" )
356+ self .assertEqual (5 , cnt )
357+
358+ network .link ("from" , "to" , "" , "" , "UInt32" , "UInt32" )
359+ network .link ("from" , "to" , "" , "" , "Real32" , "Real32" )
360+ network .link ("from" , "to" , "" , "" , "Real32" , "UInt32" )
361+ network .link ("from" , "to" , "" , "" , "UInt32" , "Real32" )
362+ network .initialize ()
363+
364+ if sys .version_info [0 ] >= 3 :
365+ proto = 3
366+ else :
367+ proto = 2
368+
369+ # Simple test: make sure that dumping / loading works...
370+ pickledNetwork = pickle .dumps (network , proto )
371+ network2 = pickle .loads (pickledNetwork )
372+
373+ s1 = network .getRegion ("to" ).executeCommand ("HelloWorld" , "26" , "64" );
374+ s2 = network2 .getRegion ("to" ).executeCommand ("HelloWorld" , "26" , "64" );
375+
376+ self .assertEqual (s1 ,"Hello World says: arg1=26 arg2=64" )
377+ self .assertEqual (s1 , s2 , "Simple Network pickle/unpickle failed." )
378+
334379
0 commit comments