|
1 | 1 | import numpy as np
|
2 | 2 | import py_paddle.swig_paddle as api
|
3 | 3 | from paddle.proto.ParameterConfig_pb2 import ParameterConfig
|
4 |
| - |
| 4 | +import struct |
| 5 | +import tarfile |
| 6 | +import cStringIO |
5 | 7 | from topology import Topology
|
6 | 8 |
|
7 | 9 | __all__ = ['Parameters', 'create']
|
@@ -122,6 +124,12 @@ def __getitem__(self, key):
|
122 | 124 |
|
123 | 125 | if len(self.__gradient_machines__) == 0:
|
124 | 126 | # create new parameter in python numpy.
|
| 127 | + if len(self.__tmp_params__) != 0: |
| 128 | + ret_list = [ |
| 129 | + mat for name, mat in self.__tmp_params__ if name == key |
| 130 | + ] |
| 131 | + if len(ret_list) == 1: |
| 132 | + return ret_list[0] |
125 | 133 | return np.ndarray(shape=shape, dtype=np.float32)
|
126 | 134 | else:
|
127 | 135 | for each_gradient_machine in self.__gradient_machines__:
|
@@ -228,6 +236,67 @@ def append_gradient_machine(self, gradient_machine):
|
228 | 236 |
|
229 | 237 | self.__gradient_machines__.append(gradient_machine)
|
230 | 238 |
|
| 239 | + def serialize(self, name, f): |
| 240 | + """ |
| 241 | +
|
| 242 | + :param name: |
| 243 | + :param f: |
| 244 | + :type f: file |
| 245 | + :return: |
| 246 | + """ |
| 247 | + param = self.get(name) |
| 248 | + size = reduce(lambda a, b: a * b, param.shape) |
| 249 | + f.write(struct.pack("IIQ", 0, 4, size)) |
| 250 | + param = param.astype(np.float32) |
| 251 | + f.write(param.tobytes()) |
| 252 | + |
| 253 | + def deserialize(self, name, f): |
| 254 | + """ |
| 255 | +
|
| 256 | + :param name: |
| 257 | + :param f: |
| 258 | + :type f: file |
| 259 | + :return: |
| 260 | + """ |
| 261 | + f.read(16) # header |
| 262 | + arr = np.frombuffer(f.read(), dtype=np.float32) |
| 263 | + self.set(name, arr.reshape(self.get_shape(name))) |
| 264 | + |
| 265 | + def to_tar(self, f): |
| 266 | + tar = tarfile.TarFile(fileobj=f, mode='w') |
| 267 | + for nm in self.names(): |
| 268 | + buf = cStringIO.StringIO() |
| 269 | + self.serialize(nm, buf) |
| 270 | + tarinfo = tarfile.TarInfo(name=nm) |
| 271 | + buf.seek(0) |
| 272 | + tarinfo.size = len(buf.getvalue()) |
| 273 | + tar.addfile(tarinfo, buf) |
| 274 | + |
| 275 | + conf = self.__param_conf__[nm] |
| 276 | + confStr = conf.SerializeToString() |
| 277 | + tarinfo = tarfile.TarInfo(name="%s.protobuf" % nm) |
| 278 | + tarinfo.size = len(confStr) |
| 279 | + buf = cStringIO.StringIO(confStr) |
| 280 | + buf.seek(0) |
| 281 | + tar.addfile(tarinfo, fileobj=buf) |
| 282 | + |
| 283 | + @staticmethod |
| 284 | + def from_tar(f): |
| 285 | + params = Parameters() |
| 286 | + tar = tarfile.TarFile(fileobj=f, mode='r') |
| 287 | + for finfo in tar: |
| 288 | + assert isinstance(finfo, tarfile.TarInfo) |
| 289 | + if finfo.name.endswith('.protobuf'): |
| 290 | + f = tar.extractfile(finfo) |
| 291 | + conf = ParameterConfig() |
| 292 | + conf.ParseFromString(f.read()) |
| 293 | + params.__append_config__(conf) |
| 294 | + |
| 295 | + for param_name in params.names(): |
| 296 | + f = tar.extractfile(param_name) |
| 297 | + params.deserialize(param_name, f) |
| 298 | + return params |
| 299 | + |
231 | 300 |
|
232 | 301 | def __get_parameter_in_gradient_machine__(gradient_machine, name):
|
233 | 302 | """
|
|
0 commit comments