17
17
import os
18
18
import collections
19
19
from .. import core
20
- from ..framework import Variable , Parameter , default_main_program
21
- from .layers import Layer
20
+ from ..framework import Variable , default_main_program
22
21
23
22
__all__ = ['save_persistables' , 'load_persistables' ]
24
23
25
24
26
- def save_persistables (obj , dirname , filename = None ):
25
+ def save_persistables (vardict , dirname , filename = None ):
27
26
"""
28
27
This function filters out all variables in layer.parameters from the
29
28
give `layer` and then trys to load these variables from the folder
@@ -35,7 +34,7 @@ def save_persistables(obj, dirname, filename=None):
35
34
the file name.
36
35
37
36
Args:
38
- var_list (dict of Parameters|Layer ): The parameters will
37
+ vardict (dict of Parameters): The parameters will
39
38
be saved. If it is None, nothing
40
39
will be deal.
41
40
dirname(str): The directory path.
@@ -69,17 +68,14 @@ def save_persistables(obj, dirname, filename=None):
69
68
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
70
69
init_cell)
71
70
param_path = "./my_paddle_model"
72
- fluid.imperative.checkpoint.save_persistables(ptb_model.parameters (), dirname=param_path,
71
+ fluid.imperative.checkpoint.save_persistables(ptb_model.state_dict (), dirname=param_path,
73
72
layer=ptb_model)
74
73
"""
75
- if isinstance (obj , collections .OrderedDict ):
76
- _save_var_to_file (obj , dirname , filename )
77
- elif isinstance (obj , Layer ):
78
- _save_var_to_file (
79
- obj .state_dict (include_sublayers = True ), dirname , filename )
74
+ if isinstance (vardict , collections .OrderedDict ):
75
+ _save_var_to_file (vardict , dirname , filename )
80
76
81
77
82
- def load_persistables (obj , dirname , filename = None ):
78
+ def load_persistables (vardict , dirname , filename = None ):
83
79
"""
84
80
This function trys to load persistable variables from the folder
85
81
`dirname` or the file `filename`.
@@ -90,7 +86,7 @@ def load_persistables(obj, dirname, filename=None):
90
86
the file name.
91
87
92
88
Args:
93
- obj (dict of Parameters|Layer ): The parameters will be loaded.
89
+ vardict (dict of Parameters): The parameters will be loaded.
94
90
dirname(str): The directory path.
95
91
filename(str|None): The file which saved all variables, this file path should be end with '.npz'. If variables were
96
92
saved in differnet files, set it to None.
@@ -111,16 +107,13 @@ def load_persistables(obj, dirname, filename=None):
111
107
my_layer = layer(fluid.imperative.Layer)
112
108
param_path = "./my_paddle_model"
113
109
filename = "model.file"
114
- param_dict = fluid.imperative.checkpoint.load_persistables(my_layer, var_list , param_path,
110
+ param_dict = fluid.imperative.checkpoint.load_persistables(my_layer.state_dict() , param_path,
115
111
filename=filename)
116
112
param_1 = param_dict['PtbModel_0.w_1']
117
113
118
114
"""
119
- if isinstance (obj , collections .OrderedDict ):
120
- return _load_var_from_file (obj , dirname , filename )
121
- elif isinstance (obj , Layer ):
122
- return _load_var_from_file (
123
- obj .state_dict (include_sublayers = True ), dirname , filename )
115
+ if isinstance (vardict , collections .OrderedDict ):
116
+ return _load_var_from_file (vardict , dirname , filename )
124
117
125
118
return {}
126
119
0 commit comments