Skip to content

Commit c36a3f4

Browse files
committed
Add unittest for serialize/deserialize.
1 parent d34eb34 commit c36a3f4

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

python/paddle/v2/parameters.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ def __getitem__(self, key):
124124

125125
if len(self.__gradient_machines__) == 0:
126126
# create new parameter in python numpy.
127+
if len(self.__tmp_params__) != 0:
128+
ret_list = [
129+
mat for name, mat in self.__tmp_params__ if name == key
130+
]
131+
if len(ret_list) == 1:
132+
return ret_list[0]
127133
return np.ndarray(shape=shape, dtype=np.float32)
128134
else:
129135
for each_gradient_machine in self.__gradient_machines__:

python/paddle/v2/tests/run_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ cd $SCRIPTPATH
2222

2323
$1 -m pip install ../../../../paddle/dist/*.whl
2424

25-
test_list="test_data_feeder.py"
25+
test_list="test_data_feeder.py test_parameters.py"
2626

2727
export PYTHONPATH=$PWD/../../../../python/
2828

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import unittest
2+
import sys
3+
4+
try:
5+
import py_paddle
6+
7+
del py_paddle
8+
except ImportError:
9+
print >> sys.stderr, "It seems swig of Paddle is not installed, this " \
10+
"unittest will not be run."
11+
sys.exit(0)
12+
13+
import paddle.v2.parameters as parameters
14+
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
15+
import random
16+
import cStringIO
17+
import numpy
18+
19+
20+
def __rand_param_config__(name):
21+
conf = ParameterConfig()
22+
conf.name = name
23+
size = 1
24+
for i in xrange(2):
25+
dim = random.randint(1, 1000)
26+
conf.dims.append(dim)
27+
size *= dim
28+
conf.size = size
29+
assert conf.IsInitialized()
30+
return conf
31+
32+
33+
class TestParameters(unittest.TestCase):
34+
def test_serialization(self):
35+
params = parameters.Parameters()
36+
params.__append_config__(__rand_param_config__("param_0"))
37+
params.__append_config__(__rand_param_config__("param_1"))
38+
39+
for name in params.names():
40+
param = params.get(name)
41+
param[:] = numpy.random.uniform(
42+
-1.0, 1.0, size=params.get_shape(name))
43+
params.set(name, param)
44+
45+
tmp_file = cStringIO.StringIO()
46+
params.to_tar(tmp_file)
47+
tmp_file.seek(0)
48+
params_dup = parameters.Parameters.from_tar(tmp_file)
49+
50+
self.assertEqual(params_dup.names(), params.names())
51+
52+
for name in params.names():
53+
self.assertEqual(params.get_shape(name), params_dup.get_shape(name))
54+
p0 = params.get(name)
55+
p1 = params_dup.get(name)
56+
self.assertTrue(numpy.isclose(p0, p1).all())
57+
58+
59+
if __name__ == '__main__':
60+
unittest.main()

0 commit comments

Comments
 (0)