1919import unittest
2020import pytest
2121import numpy as np
22+ import sys
23+ import os
24+ import pickle
2225
26+
2327from htm .bindings .regions .PyRegion import PyRegion
2428from htm .bindings .sdr import SDR
2529import htm .bindings .engine_internal as engine
@@ -96,11 +100,12 @@ def setUp(self):
96100 engine .Network .cleanup ()
97101 engine .Network .registerPyRegion (LinkRegion .__module__ , LinkRegion .__name__ )
98102
99- @pytest .mark .skip (reason = "pickle support needs work...another PR" )
100103 def testSerializationWithPyRegion (self ):
101104 """Test (de)serialization of network containing a python region"""
102105 engine .Network .registerPyRegion (__name__ ,
103106 SerializationTestPyRegion .__name__ )
107+
108+ file_path = "SerializationTest.stream"
104109 try :
105110 srcNet = engine .Network ()
106111 srcNet .addRegion (SerializationTestPyRegion .__name__ ,
@@ -111,12 +116,20 @@ def testSerializationWithPyRegion(self):
111116 }))
112117
113118 # Serialize
114- srcNet .saveToFile ("SerializationTest.stream" )
119+ # Note: This will do the following:
120+ # - Call network.saveToFile(), in C++. this opens the file.
121+ # - that calls network.save(stream)
122+ # - that will use Cereal to serialize the Network object.
123+ # - that will serialize the Region object.
124+ # - that will serialize PyBindRegion object because this is a python Region.
125+ # - that will use pickle to serialize SerializationTestPyRegion in
126+ # serialization_test_py_region.py into Base64.
127+ srcNet .saveToFile (file_path , engine .SerializableFormat .BINARY )
115128
116129
117130 # Deserialize
118131 destNet = engine .Network ()
119- destNet .loadFromFile ("SerializationTest.stream" )
132+ destNet .loadFromFile (file_path )
120133
121134 destRegion = destNet .getRegion (SerializationTestPyRegion .__name__ )
122135
@@ -125,6 +138,8 @@ def testSerializationWithPyRegion(self):
125138
126139 finally :
127140 engine .Network .unregisterPyRegion (SerializationTestPyRegion .__name__ )
141+ if os .path .isfile (file_path ):
142+ os .unlink ("SerializationTest.stream" )
128143
129144
130145 def testSimpleTwoRegionNetworkIntrospection (self ):
@@ -174,7 +189,6 @@ def testNetworkLinkTypeValidation(self):
174189 network .link ("from" , "to" , "" , "" , "UInt32" , "Real32" )
175190
176191
177- @pytest .mark .skip (reason = "parameter types don't match." )
178192 def testParameters (self ):
179193
180194 n = engine .Network ()
@@ -333,4 +347,35 @@ def testExecuteCommand2(self):
333347 result = r .executeCommand ("HelloWorld" , 42 , lst )
334348 self .assertTrue (result == "Hello World says: arg1=42 arg2=['list arg', 86]" )
335349
350+ def testNetworkPickle (self ):
351+ """
352+ Test region pickling/unpickling.
353+ """
354+ network = engine .Network ()
355+ r_from = network .addRegion ("from" , "py.LinkRegion" , "" )
356+ r_to = network .addRegion ("to" , "py.LinkRegion" , "" )
357+ cnt = r_from .getOutputElementCount ("UInt32" )
358+ self .assertEqual (5 , cnt )
359+
360+ network .link ("from" , "to" , "" , "" , "UInt32" , "UInt32" )
361+ network .link ("from" , "to" , "" , "" , "Real32" , "Real32" )
362+ network .link ("from" , "to" , "" , "" , "Real32" , "UInt32" )
363+ network .link ("from" , "to" , "" , "" , "UInt32" , "Real32" )
364+ network .initialize ()
365+
366+ if sys .version_info [0 ] >= 3 :
367+ proto = 3
368+ else :
369+ proto = 2
370+
371+ # Simple test: make sure that dumping / loading works...
372+ pickledNetwork = pickle .dumps (network , proto )
373+ network2 = pickle .loads (pickledNetwork )
374+
375+ s1 = network .getRegion ("to" ).executeCommand ("HelloWorld" , "26" , "64" );
376+ s2 = network2 .getRegion ("to" ).executeCommand ("HelloWorld" , "26" , "64" );
377+
378+ self .assertEqual (s1 ,"Hello World says: arg1=26 arg2=64" )
379+ self .assertEqual (s1 , s2 , "Simple Network pickle/unpickle failed." )
380+
336381
0 commit comments