Skip to content

Commit 999c291

Browse files
authored
[Cherry-pick]Delete the function of saving layer object. (#34039)
* Save all the information of 'ParamBase' in 'Layer'. (#33500) * Save all the information of 'ParamBase' in 'Layer'. * edit unittest * delete the function of saving layer object. (#33697) * delete the function of saving layer object. * edit doc of paddle.save/load and polish error message
1 parent 0f266ac commit 999c291

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

python/paddle/fluid/tests/unittests/test_paddle_save_load.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -869,21 +869,11 @@ def test_save_load_layer(self):
869869
layer2 = LinearNet()
870870
layer1.eval()
871871
layer2.eval()
872+
origin_layer = (layer1, layer2)
872873
origin = (layer1(inps), layer2(inps))
873874
path = "test_save_load_layer_/layer.pdmodel"
874-
paddle.save((layer1, layer2), path)
875-
876-
# static
877-
paddle.enable_static()
878875
with self.assertRaises(ValueError):
879-
paddle.load(path)
880-
# dygraph
881-
paddle.disable_static()
882-
883-
loaded_layer = paddle.load(path)
884-
loaded_result = [l(inps) for l in loaded_layer]
885-
for i in range(len(origin)):
886-
self.assertTrue((origin[i] - loaded_result[i]).abs().max() < 1e-10)
876+
paddle.save(origin_layer, path)
887877

888878

889879
if __name__ == '__main__':

python/paddle/framework/io.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def _pickle_save(obj, f, protocol):
232232
raise ValueError("Expected 1<'protocol'<5, but received protocol={}".
233233
format(protocol))
234234

235-
def reudce_varbase(self):
235+
def reduce_varbase(self):
236236
data = self.numpy()
237237
name = self.name
238238

@@ -243,16 +243,32 @@ def reduce_LoDTensor(self):
243243

244244
return (eval, ('data', {'data': data}))
245245

246+
def reduce_Layer(self):
247+
raise ValueError(
248+
"paddle do not support saving `paddle.nn.Layer` object.")
249+
250+
dispatch_table_layer = dict()
251+
252+
def create_layer_dispatch_table(layer):
253+
dispatch_table_layer[layer.__class__] = reduce_Layer
254+
return layer
255+
256+
_parse_every_object(obj, lambda v: isinstance(v, core.Layer),
257+
create_layer_dispatch_table)
258+
246259
def add_dispatch_table():
247260
# This is not a good method, because the pickle module has been modified.
248-
pickle.dispatch_table[core.VarBase] = reudce_varbase
249-
pickle.dispatch_table[ParamBase] = reudce_varbase
261+
pickle.dispatch_table[core.VarBase] = reduce_varbase
262+
pickle.dispatch_table[ParamBase] = reduce_varbase
250263
pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor
264+
pickle.dispatch_table.update(dispatch_table_layer)
251265

252266
def pop_dispatch_table():
253267
pickle.dispatch_table.pop(core.VarBase)
254268
pickle.dispatch_table.pop(core.LoDTensor)
255269
pickle.dispatch_table.pop(ParamBase)
270+
for k in dispatch_table_layer:
271+
pickle.dispatch_table.pop(k)
256272

257273
# When value of dict is lager than 4GB ,there is a Bug on 'MAC python3'
258274
if sys.platform == 'darwin' and sys.version_info.major == 3:
@@ -272,10 +288,10 @@ def pop_dispatch_table():
272288
pickler = pickle.Pickler(f, protocol)
273289
pickler.dispatch_table = copyreg.dispatch_table.copy()
274290

275-
pickler.dispatch_table[core.VarBase] = reudce_varbase
291+
pickler.dispatch_table[core.VarBase] = reduce_varbase
276292
pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor
277-
pickler.dispatch_table[ParamBase] = reudce_varbase
278-
293+
pickler.dispatch_table[ParamBase] = reduce_varbase
294+
pickler.dispatch_table.update(dispatch_table_layer)
279295
pickler.dump(obj)
280296

281297

@@ -496,7 +512,7 @@ def save(obj, path, protocol=4, **configs):
496512
Save an object to the specified path.
497513
498514
.. note::
499-
Now supports saving ``state_dict`` of Layer/Optimizer, Layer, Tensor and nested structure containing Tensor, Program.
515+
Now supports saving ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program.
500516
501517
.. note::
502518
Different from ``paddle.jit.save``, since the save result of ``paddle.save`` is a single file,
@@ -690,7 +706,7 @@ def load(path, **configs):
690706
Load an object can be used in paddle from specified path.
691707
692708
.. note::
693-
Now supports loading ``state_dict`` of Layer/Optimizer, Layer, Tensor and nested structure containing Tensor, Program.
709+
Now supports loading ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program.
694710
695711
.. note::
696712
In order to use the model parameters saved by paddle more efficiently,

0 commit comments

Comments
 (0)