Skip to content

Commit 963bd5d

Browse files
authored
Merge pull request #1520 from reyoung/feature/serialize_deserialize_in_parameters
Add save/load parameters.
2 parents 5f2cbce + c36a3f4 commit 963bd5d

File tree

6 files changed: +153 additions, -8 deletions

demo/mnist/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ plot.png
55
train.log
66
*pyc
77
.ipynb_checkpoints
8+
params.pkl
9+
params.tar
10+
params.tar.gz

demo/mnist/api_train_v2.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import paddle.v2 as paddle
2+
import gzip
23

34

45
def softmax_regression(img):
@@ -71,7 +72,11 @@ def main():
7172

7273
cost = paddle.layer.classification_cost(input=predict, label=label)
7374

74-
parameters = paddle.parameters.create(cost)
75+
try:
76+
with gzip.open('params.tar.gz', 'r') as f:
77+
parameters = paddle.parameters.Parameters.from_tar(f)
78+
except IOError:
79+
parameters = paddle.parameters.create(cost)
7580

7681
optimizer = paddle.optimizer.Momentum(
7782
learning_rate=0.1 / 128.0,
@@ -86,10 +91,18 @@ def main():
8691

8792
def event_handler(event):
8893
if isinstance(event, paddle.event.EndIteration):
89-
if event.batch_id % 100 == 0:
90-
print "Pass %d, Batch %d, Cost %f, %s" % (
91-
event.pass_id, event.batch_id, event.cost, event.metrics)
92-
if isinstance(event, paddle.event.EndPass):
94+
if event.batch_id % 1000 == 0:
95+
result = trainer.test(reader=paddle.reader.batched(
96+
paddle.dataset.mnist.test(), batch_size=256))
97+
98+
print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
99+
event.pass_id, event.batch_id, event.cost, event.metrics,
100+
result.metrics)
101+
102+
with gzip.open('params.tar.gz', 'w') as f:
103+
parameters.to_tar(f)
104+
105+
elif isinstance(event, paddle.event.EndPass):
93106
result = trainer.test(reader=paddle.reader.batched(
94107
paddle.dataset.mnist.test(), batch_size=128))
95108
print "Test with Pass %d, Cost %f, %s\n" % (

python/paddle/v2/parameters.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import numpy as np
22
import py_paddle.swig_paddle as api
33
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
4-
4+
import struct
5+
import tarfile
6+
import cStringIO
57
from topology import Topology
68

79
__all__ = ['Parameters', 'create']
@@ -122,6 +124,12 @@ def __getitem__(self, key):
122124

123125
if len(self.__gradient_machines__) == 0:
124126
# create new parameter in python numpy.
127+
if len(self.__tmp_params__) != 0:
128+
ret_list = [
129+
mat for name, mat in self.__tmp_params__ if name == key
130+
]
131+
if len(ret_list) == 1:
132+
return ret_list[0]
125133
return np.ndarray(shape=shape, dtype=np.float32)
126134
else:
127135
for each_gradient_machine in self.__gradient_machines__:
@@ -228,6 +236,67 @@ def append_gradient_machine(self, gradient_machine):
228236

229237
self.__gradient_machines__.append(gradient_machine)
230238

239+
def serialize(self, name, f):
    """Write the parameter named *name* to *f* in binary form.

    Layout: a 16-byte header ``struct.pack("IIQ", 0, 4, size)`` --
    the fields look like (version, element width in bytes, element
    count); TODO confirm field semantics against the C++ reader --
    followed by the raw float32 payload.

    :param name: name of the parameter to write.
    :param f: writable binary file-like object.
    :type f: file
    :return: None
    """
    param = self.get(name)
    # numpy's .size is the element count; unlike
    # reduce(lambda a, b: a * b, param.shape) it also works for a
    # scalar (empty shape), where reduce would raise TypeError.
    size = param.size
    f.write(struct.pack("IIQ", 0, 4, size))
    # force float32 on disk regardless of the in-memory dtype
    param = param.astype(np.float32)
    f.write(param.tobytes())
252+
253+
def deserialize(self, name, f):
    """Read one parameter (as written by :meth:`serialize`) from *f*
    and install it under *name*.

    :param name: parameter name; its registered shape drives the reshape.
    :param f: readable binary file-like object.
    :type f: file
    :return: None
    """
    # Skip the 16-byte header; everything after it is raw float32 data.
    f.read(16)
    flat = np.frombuffer(f.read(), dtype=np.float32)
    shaped = flat.reshape(self.get_shape(name))
    self.set(name, shaped)
264+
265+
def to_tar(self, f):
    """Archive every parameter into an (uncompressed) tar stream on *f*.

    Two members are added per parameter ``name``: ``name`` with the
    serialized float32 payload, and ``name.protobuf`` with the
    serialized ParameterConfig message.

    :param f: writable binary file-like object.
    :type f: file
    """
    archive = tarfile.TarFile(fileobj=f, mode='w')
    for nm in self.names():
        # member 1: the raw parameter payload
        payload = cStringIO.StringIO()
        self.serialize(nm, payload)
        payload.seek(0)
        info = tarfile.TarInfo(name=nm)
        info.size = len(payload.getvalue())
        archive.addfile(info, fileobj=payload)

        # member 2: the parameter's protobuf configuration
        conf_str = self.__param_conf__[nm].SerializeToString()
        conf_info = tarfile.TarInfo(name="%s.protobuf" % nm)
        conf_info.size = len(conf_str)
        conf_buf = cStringIO.StringIO(conf_str)
        conf_buf.seek(0)
        archive.addfile(conf_info, fileobj=conf_buf)
282+
283+
@staticmethod
def from_tar(f):
    """Rebuild a Parameters object from a tar stream written by to_tar.

    Pass 1 registers every ``*.protobuf`` member as a parameter
    config; pass 2 deserializes each registered parameter's payload.

    :param f: readable binary file-like object positioned at the
        start of the archive.
    :type f: file
    :return: the reconstructed parameters.
    :rtype: Parameters
    """
    params = Parameters()
    tar = tarfile.TarFile(fileobj=f, mode='r')
    for finfo in tar:
        assert isinstance(finfo, tarfile.TarInfo)
        if finfo.name.endswith('.protobuf'):
            # Use a distinct name for the extracted member: the
            # original code rebound the input parameter `f` here,
            # shadowing the archive's own file handle.
            member = tar.extractfile(finfo)
            conf = ParameterConfig()
            conf.ParseFromString(member.read())
            params.__append_config__(conf)

    for param_name in params.names():
        member = tar.extractfile(param_name)
        params.deserialize(param_name, member)
    return params
299+
231300

232301
def __get_parameter_in_gradient_machine__(gradient_machine, name):
233302
"""

python/paddle/v2/tests/run_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ cd $SCRIPTPATH
2222

2323
$1 -m pip install ../../../../paddle/dist/*.whl
2424

25-
test_list="test_data_feeder.py"
25+
test_list="test_data_feeder.py test_parameters.py"
2626

2727
export PYTHONPATH=$PWD/../../../../python/
2828

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import unittest
2+
import sys
3+
4+
try:
5+
import py_paddle
6+
7+
del py_paddle
8+
except ImportError:
9+
print >> sys.stderr, "It seems swig of Paddle is not installed, this " \
10+
"unittest will not be run."
11+
sys.exit(0)
12+
13+
import paddle.v2.parameters as parameters
14+
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
15+
import random
16+
import cStringIO
17+
import numpy
18+
19+
20+
def __rand_param_config__(name):
    """Build a ParameterConfig named *name* with two random dims, each
    drawn uniformly from [1, 1000]; ``size`` is the product of the dims."""
    conf = ParameterConfig()
    conf.name = name
    total = 1
    for _ in xrange(2):
        dim = random.randint(1, 1000)
        conf.dims.append(dim)
        total *= dim
    conf.size = total
    assert conf.IsInitialized()
    return conf
31+
32+
33+
class TestParameters(unittest.TestCase):
    def test_serialization(self):
        # Round-trip two randomly-shaped, randomly-valued parameters
        # through to_tar/from_tar and check names, shapes and values.
        params = parameters.Parameters()
        params.__append_config__(__rand_param_config__("param_0"))
        params.__append_config__(__rand_param_config__("param_1"))

        # fill each parameter with uniform noise in [-1, 1)
        for name in params.names():
            param = params.get(name)
            param[:] = numpy.random.uniform(
                -1.0, 1.0, size=params.get_shape(name))
            params.set(name, param)

        # serialize into an in-memory buffer, then parse it back
        archive = cStringIO.StringIO()
        params.to_tar(archive)
        archive.seek(0)
        restored = parameters.Parameters.from_tar(archive)

        self.assertEqual(restored.names(), params.names())

        for name in params.names():
            self.assertEqual(params.get_shape(name), restored.get_shape(name))
            original_value = params.get(name)
            restored_value = restored.get(name)
            self.assertTrue(numpy.isclose(original_value, restored_value).all())
57+
58+
59+
if __name__ == '__main__':
60+
unittest.main()

python/paddle/v2/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def __init__(self, cost, parameters, update_equation):
5757
self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
5858
self.__optimizer__.enable_types())
5959
assert isinstance(gm, api.GradientMachine)
60-
parameters.append_gradient_machine(gm)
6160
self.__gradient_machine__ = gm
6261
self.__gradient_machine__.randParameters()
62+
parameters.append_gradient_machine(gm)
6363

6464
def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
6565
"""

0 commit comments

Comments
 (0)